aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/distribute
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-08-07 22:41:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 22:45:50 -0700
commit9c161fa8e869dd8e982a0b38ee540f251f52c3ab (patch)
tree16b1763d2f725eb9b40a8f6d0c24b10af9667cd3 /tensorflow/python/distribute
parent35847b960fe59c0e2e5371db55041a62b65dbb37 (diff)
Support independet-worker mode in distribute coordinator.
PiperOrigin-RevId: 207835757
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--tensorflow/python/distribute/BUILD2
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py269
-rw-r--r--tensorflow/python/distribute/distribute_coordinator_test.py336
3 files changed, 501 insertions, 106 deletions
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 2bd0b4320a..68d8b8d13b 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -22,7 +22,7 @@ py_library(
py_test(
name = "distribute_coordinator_test",
- size = "small",
+ size = "large",
srcs = ["distribute_coordinator_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index dab1ed43ca..fc9ca4ac4a 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -32,6 +32,23 @@ class _TaskType(object):
WORKER = "worker"
CHIEF = "chief"
EVALUATOR = "evaluator"
+ CLIENT = "client"
+
+
+# TODO(yuefengz): support another mode where the client colocates with one
+# worker.
+class CoordinatorMode(object):
+ """Specify how distribute coordinator runs."""
+ # The default mode where distribute coordinator will run as a standalone
+ # client and connects to remote servers for training. Each remote server can
+ # use the distribute coordinator binary with task_type set correctly which
+ # will then turn into standard servers.
+ SPLIT_CLIENT = 0
+
+ # The distribute coordinator runs on each worker. It will run a standard
+ # server on each worker and optionally run the `worker_fn` that is configured
+ # to talk to its standard server.
+ INDEPENDENT_WORKER = 1
_worker_context = threading.local()
@@ -99,7 +116,6 @@ class _WorkerContext(object):
cluster_spec,
task_type,
task_id,
- between_graph=False,
rpc_layer="grpc",
worker_barrier=None):
"""Initialize the worker context object.
@@ -108,27 +124,15 @@ class _WorkerContext(object):
cluster_spec: a ClusterSpec object. It can be empty or None in the local
training case.
task_type: a string indicating the role of the corresponding task, such as
- "worker" or "ps". It can be None if it is local training or
- `between_graph` is False.
+ "worker" or "ps". It can be None if it is local training or in-graph
+ replicated training.
task_id: an integer indicating id of the corresponding task. It can be
- None if it is local training or `between_graph` is False.
- between_graph: whether it is between-graph replication or not.
+ None if it is local training or in-graph replicated training.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
worker_barrier: optional, the barrier object for worker synchronization.
-
- Raises:
- ValueError: if task_type or task_id is Node or empty and it is distributed
- between-graph replicated training.
"""
- if cluster_spec and between_graph:
- if not task_type or task_id is None:
- raise ValueError("`task_type` and `task_id` must be set in the "
- "distributed between-graph replicated training.")
- if task_type not in cluster_spec.jobs:
- raise ValueError("`task_type` %r not found in the `cluster_spec` %r" %
- (task_type, cluster_spec))
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
@@ -138,11 +142,16 @@ class _WorkerContext(object):
self._num_workers = _get_num_workers(cluster_spec)
self._is_chief_node = self._is_chief()
+ def _debug_message(self):
+ return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+ self._cluster_spec, self.task_type, self.task_id)
+
def __enter__(self):
old_context = get_current_worker_context()
if old_context:
raise ValueError(
- "You cannot run distribute coordinator in a `worker_fn`.")
+ "You cannot run distribute coordinator in a `worker_fn`.\t" +
+ self._debug_message())
_worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
@@ -159,7 +168,6 @@ class _WorkerContext(object):
# case we use the chief or first worker's master target.
if not self._task_type:
if _TaskType.CHIEF in self._cluster_spec.jobs:
- assert not self.between_graph
task_type = _TaskType.CHIEF
task_id = 0
else:
@@ -177,7 +185,8 @@ class _WorkerContext(object):
def _is_chief(self):
"""Return whether the task is the chief worker."""
- if (not self._cluster_spec or self._task_type in [_TaskType.CHIEF, None]):
+ if (not self._cluster_spec or
+ self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
return True
# If not local and chief not in the cluster_spec, use the first worker as
@@ -194,14 +203,19 @@ class _WorkerContext(object):
ValueError: if `worker_barrier` is not passed to the __init__ method.
"""
if not self._worker_barrier:
- raise ValueError(
- "`worker_barrier is not set in the worker context.`")
+ raise ValueError("`worker_barrier is not set in the worker context.` \t" +
+ self._debug_message())
self._worker_barrier.wait()
@property
+ def has_barrier(self):
+ """Whether the barrier is set or not."""
+ return self._worker_barrier is not None
+
+ @property
def distributed_mode(self):
"""Whether it is distributed training or not."""
- return bool(self._cluster_spec)
+ return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR
@property
def cluster_spec(self):
@@ -234,24 +248,110 @@ class _WorkerContext(object):
return self._num_workers
-def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
- worker_barrier):
- with _WorkerContext(cluster_spec, task_type, task_id, between_graph,
- rpc_layer, worker_barrier):
+def _run_single_worker(worker_fn,
+ cluster_spec,
+ task_type,
+ task_id,
+ rpc_layer,
+ worker_barrier=None):
+ """Runs a single worker by calling `worker_fn` under context."""
+ with _WorkerContext(
+ cluster_spec,
+ task_type,
+ task_id,
+ rpc_layer=rpc_layer,
+ worker_barrier=worker_barrier):
worker_fn()
+def _run_std_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None):
+ """Runs a standard server."""
+ server = server_lib.Server(
+ cluster_spec,
+ job_name=task_type,
+ task_index=task_id,
+ config=session_config,
+ protocol=rpc_layer)
+ server.start()
+ return server
+
+
+def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+ """Runs a standalone client for between-graph replication."""
+ eval_thread = None
+ if _TaskType.EVALUATOR in cluster_spec.jobs:
+ eval_thread = threading.Thread(
+ target=_run_single_worker,
+ args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ kwargs={
+ "rpc_layer": rpc_layer,
+ })
+ eval_thread.start()
+
+ threads = []
+ worker_barrier = _Barrier(_get_num_workers(cluster_spec))
+ for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
+ for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
+ t = threading.Thread(
+ target=_run_single_worker,
+ args=(worker_fn, cluster_spec, task_type, task_id),
+ kwargs={
+ "rpc_layer": rpc_layer,
+ "worker_barrier": worker_barrier
+ })
+ t.start()
+ threads.append(t)
+
+ # TODO(yuefengz): wrap threads into thread coordinator?
+ for t in threads:
+ t.join()
+
+ # TODO(yuefengz): is it necessary to join eval thread?
+ if eval_thread:
+ eval_thread.join()
+
+
+def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+ """Runs a standalone client for in-graph replication."""
+ eval_thread = None
+ if _TaskType.EVALUATOR in cluster_spec.jobs:
+ eval_thread = threading.Thread(
+ target=_run_single_worker,
+ args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ kwargs={
+ "rpc_layer": rpc_layer,
+ })
+ eval_thread.start()
+
+ _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ if eval_thread:
+ eval_thread.join()
+
+
+# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): we may need a smart way to figure out whether the current task
+# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
+ mode=CoordinatorMode.SPLIT_CLIENT,
cluster_spec=None,
+ task_type=None,
+ task_id=None,
between_graph=False,
- rpc_layer=None):
- """Run the coordinator for distributed TensorFlow.
-
- This function runs a unified and split coordinator for distributed TensorFlow.
- Given a `cluster_spec` specifying server addresses and their roles in a
- cluster, this coordinator will figure out how to set them up, give the
- underlying function the right targets for master sessions and coordinate their
- training.
+ rpc_layer="grpc"):
+ """Runs the coordinator for distributed TensorFlow.
+
+ This function runs a split coordinator for distributed TensorFlow in its
+ default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
+ server addresses and their roles in a cluster, this coordinator will figure
+ out how to set them up, give the underlying function the right targets for
+ master sessions via a scope object and coordinate their training. The cluster
+ consisting of standard servers needs to be brought up either with the standard
+ server binary or with a binary running distribute coordinator with `task_type`
+ set to non-client type which will then turn into standard servers.
In addition to be the distribute coordinator, this is also the source of
configurations for each job in the distributed training. As there are multiple
@@ -261,9 +361,14 @@ def run_distribute_coordinator(worker_fn,
In the between-graph replicated training, this coordinator will create
multiple threads and each calls the `worker_fn` which is supposed to create
- its own graph and connect to one worker master given by its coordinator
- context. In the in-graph replicated training, it has only one thread calling
- this `worker_fn`.
+ its own graph and connect to one worker master given by its context object. In
+ the in-graph replicated training, it has only one thread calling this
+ `worker_fn`.
+
+ Another mode is the INDEPENDENT_WORKER mode where each server runs a
+ distribute coordinator which will start a standard server and optionally runs
+ `worker_fn` depending whether it is between-graph training or in-graph
+ replicated training.
The `worker_fn` defines the training logic and is called under a its own
worker context which can be accessed to via `get_current_worker_context`. A
@@ -274,13 +379,14 @@ def run_distribute_coordinator(worker_fn,
`worker_fn` or to define different environment variables for different
`worker_fn`s.
- The `worker_fn` for the between-graph replication is defined as if there are
- only one worker corresponding to the `worker_fn` and possibly ps jobs. It
- assigns variables to parameter servers and all other operations to that
- worker. In the in-graph replication case, the `worker_fn` has to define
- operations for all worker jobs. Using a distribution strategy can simplify the
- `worker_fn` by not having to worry about the replication and device assignment
- of variables and operations.
+ The `worker_fn` for the between-graph replication is defined as if there is
+ only one worker corresponding to the `worker_fn` and possibly ps jobs. For
+ example, when training with parameter servers, it assigns variables to
+ parameter servers and all other operations to that worker. In the in-graph
+ replication case, the `worker_fn` has to define operations for all worker
+ jobs. Using a distribution strategy can simplify the `worker_fn` by not having
+ to worry about the replication and device assignment of variables and
+ operations.
This method is intended to be invoked by high-level APIs so that users don't
have to explictly call it to run this coordinator. For those who don't use
@@ -309,8 +415,11 @@ def run_distribute_coordinator(worker_fn,
Args:
worker_fn: the function to be called and given the access to a coordinator
context object.
+ mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
+ task_type: the current task type, optional if this is a client.
+ task_id: the current task id, optional if this is a client.
between_graph: a boolean. It is only useful when `cluster_spec` is set and
not empty. If true, it will use between-graph replicated training;
otherwise it will use in-graph replicated training.
@@ -320,9 +429,13 @@ def run_distribute_coordinator(worker_fn,
ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
a ClusterSpec.
"""
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if not cluster_spec:
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
cluster_spec = tf_config.get("cluster", {})
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", task_type)
+ task_id = int(task_env.get("index", task_id))
if cluster_spec:
if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
@@ -333,29 +446,45 @@ def run_distribute_coordinator(worker_fn,
"`tf.train.ClusterDef` object")
# TODO(yuefengz): validate cluster_spec.
- threads = []
- if cluster_spec and _TaskType.EVALUATOR in cluster_spec.jobs:
- t = threading.Thread(
- target=_run,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0, between_graph,
- rpc_layer, None))
- t.start()
- threads.append(t)
-
- if cluster_spec and between_graph:
- worker_barrier = _Barrier(_get_num_workers(cluster_spec))
- for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
- for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
- t = threading.Thread(
- target=_run,
- args=(worker_fn, cluster_spec, task_type, task_id, between_graph,
- rpc_layer, worker_barrier))
- t.start()
- threads.append(t)
+ if not cluster_spec:
+ # `mode` is ignored in the local case.
+ _run_single_worker(worker_fn, None, None, None, rpc_layer)
+ elif mode == CoordinatorMode.SPLIT_CLIENT:
+ # The client must know the cluster but servers in the cluster don't have to
+ # know the client.
+ if task_type in [_TaskType.CLIENT, None]:
+ if between_graph:
+ _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+ else:
+ _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+ else:
+ # If not a client job, run the standard server.
+ server = _run_std_server(
+ cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ server.join()
else:
- # Local or in-graph replicated training.
- _run(worker_fn, cluster_spec, None, None, between_graph, rpc_layer, None)
-
- # TODO(yuefengz): wrapper threads into thread coordinator?
- for t in threads:
- t.join()
+ if mode != CoordinatorMode.INDEPENDENT_WORKER:
+ raise ValueError("Unexpected coordinator mode: %r" % mode)
+
+ # Every one starts a standard server.
+ server = _run_std_server(
+ cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+
+ if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
+ if between_graph:
+ # All jobs run `worker_fn` if between-graph.
+ _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
+ rpc_layer)
+ else:
+ # Only one node runs `worker_fn` if in-graph.
+ context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+ if context.is_chief:
+ _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ else:
+ server.join()
+ elif task_type == _TaskType.EVALUATOR:
+ _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+ else:
+ if task_type != _TaskType.PS:
+ raise ValueError("Unexpected task_type: %r" % task_type)
+ server.join()
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index d7ffeb56a5..319c29ba2f 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -20,9 +20,20 @@ from __future__ import print_function
import contextlib
import copy
+import os
+import sys
import threading
import six
+# pylint: disable=invalid-name
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error:
+ _portpicker_import_error = _error
+ portpicker = None
+# pylint: enable=invalid-name
+
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
@@ -39,6 +50,11 @@ WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
+SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
+
+RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
+
NUM_WORKERS = 3
NUM_PS = 2
@@ -50,7 +66,29 @@ def _bytes_to_str(maybe_bytes):
return str(maybe_bytes, "utf-8")
-class DistributeCoordinatorTest(test.TestCase):
+def _strip_protocol(target):
+ # cluster_spec expects "host:port" strings.
+ if "//" in target:
+ return target.split("//")[1]
+ else:
+ return target
+
+
+class MockServer(object):
+
+ def __init__(self):
+ self._joined = False
+
+ def join(self):
+ assert not self._joined
+ self._joined = True
+
+ @property
+ def joined(self):
+ return self._joined
+
+
+class DistributeCoordinatorTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
@@ -60,14 +98,18 @@ class DistributeCoordinatorTest(test.TestCase):
cls._workers, cls._ps = test_util.create_local_cluster(
NUM_WORKERS, num_ps=NUM_PS)
cls._cluster_spec = {
- WORKER: [_bytes_to_str(w.target) for w in cls._workers],
- PS: [_bytes_to_str(ps.target) for ps in cls._ps]
+ WORKER: [
+ _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
+ ],
+ PS: [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps]
}
def setUp(self):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
+ self._std_servers = {}
+ self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@contextlib.contextmanager
def _test_session(self, target):
@@ -76,6 +118,30 @@ class DistributeCoordinatorTest(test.TestCase):
with session.Session(graph=None, config=config, target=target) as sess:
yield sess
+ def _create_cluster_spec(self,
+ has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
+ if num_workers:
+ cluster_spec[WORKER] = [
+ "localhost:%s" % portpicker.pick_unused_port()
+ for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec[PS] = [
+ "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
+ return cluster_spec
+
def _in_graph_worker_fn(self):
context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
@@ -98,13 +164,28 @@ class DistributeCoordinatorTest(test.TestCase):
if result_value == expected:
self._result_correct += 1
- def testInGraph(self):
- """Test it runs in-graph replicated training correctly."""
- distribute_coordinator.run_distribute_coordinator(
- self._in_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=False)
- self.assertEqual(self._result_correct, 1)
+ def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+ t = threading.Thread(
+ target=distribute_coordinator.run_distribute_coordinator,
+ args=(worker_fn,),
+ kwargs=kwargs)
+ t.start()
+ return t
+
+ def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
+ **kwargs):
+ threads = {}
+ for task_type in cluster_spec.keys():
+ threads[task_type] = []
+ for task_id in range(len(cluster_spec[task_type])):
+ t = self._run_coordinator_in_thread(
+ worker_fn,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ **kwargs)
+ threads[task_type].append(t)
+ return threads
def _between_graph_worker_fn(self):
context = distribute_coordinator.get_current_worker_context()
@@ -127,13 +208,23 @@ class DistributeCoordinatorTest(test.TestCase):
variables.global_variables_initializer().run()
# Synchronize workers after initializaton.
- context.wait_for_other_workers()
+ if context.has_barrier:
+ context.wait_for_other_workers()
+ else:
+ while True:
+ uninit_vars = sess.run(variables.report_uninitialized_variables())
+ # pylint: disable=g-explicit-length-test
+ if len(uninit_vars) == 0:
+ break
sess.run(train_op)
# Synchronize workers after one step to make sure they all have finished
# training.
- context.wait_for_other_workers()
+ if context.has_barrier:
+ context.wait_for_other_workers()
+ else:
+ self._barrier.wait()
x_val, y_val = sess.run([x, y])
@@ -143,16 +234,6 @@ class DistributeCoordinatorTest(test.TestCase):
with self._lock:
self._result_correct += 1
- def testBetweenGraph(self):
- """Test it runs between-graph replicated training correctly."""
- distribute_coordinator.run_distribute_coordinator(
- self._between_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=True)
-
- # Each finished worker will increment self._result_correct.
- self.assertEqual(self._result_correct, NUM_WORKERS)
-
def _dump_worker_context(self):
"""Dumps the propoerties of each worker context.
@@ -174,6 +255,45 @@ class DistributeCoordinatorTest(test.TestCase):
context.is_chief,
context.distributed_mode)
+ def _run_mock_std_server(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ rpc_layer=None):
+ task_type = str(task_type)
+ task_id = task_id or 0
+ with self._lock:
+ if task_type not in self._std_servers:
+ self._std_servers[task_type] = []
+ while len(self._std_servers[task_type]) <= task_id:
+ self._std_servers[task_type].append(None)
+
+ server = MockServer()
+ self._std_servers[task_type][task_id] = server
+ return server
+
+
+class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+
+ def testInGraphSplitMode(self):
+ """Test it runs in-graph replication in split client mode."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._in_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=False)
+ self.assertEqual(self._result_correct, 1)
+
+ def testBetweenGraph(self):
+ """Test it runs between-graph replication in split client mode."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=True)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
def testBetweenGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
@@ -253,15 +373,15 @@ class DistributeCoordinatorTest(test.TestCase):
# and distributed_mode.
self.assertEqual(self._worker_context[CHIEF][0],
("grpc://fake_chief", 4, True, True))
- self.assertEqual(self._worker_context[WORKER][0],
- ("grpc://" + _bytes_to_str(self._workers[0].target),
- NUM_WORKERS + 1, False, True))
- self.assertEqual(self._worker_context[WORKER][1],
- ("grpc://" + _bytes_to_str(self._workers[1].target),
- NUM_WORKERS + 1, False, True))
- self.assertEqual(self._worker_context[WORKER][2],
- ("grpc://" + _bytes_to_str(self._workers[2].target),
- NUM_WORKERS + 1, False, True))
+ self.assertEqual(
+ self._worker_context[WORKER][0],
+ (_bytes_to_str(self._workers[0].target), NUM_WORKERS + 1, False, True))
+ self.assertEqual(
+ self._worker_context[WORKER][1],
+ (_bytes_to_str(self._workers[1].target), NUM_WORKERS + 1, False, True))
+ self.assertEqual(
+ self._worker_context[WORKER][2],
+ (_bytes_to_str(self._workers[2].target), NUM_WORKERS + 1, False, True))
def testInGraphContextWithEval(self):
# Adds a EVALUATOR job.
@@ -272,7 +392,140 @@ class DistributeCoordinatorTest(test.TestCase):
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
cluster_spec=cluster_spec,
- between_graph=False)
+ between_graph=False,
+ rpc_layer=None)
+
+ # There are one "None" task and one EVALUATOR task.
+ self.assertEqual(len(self._worker_context), 2)
+ self.assertTrue("None" in self._worker_context)
+ self.assertTrue(EVALUATOR in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
+ self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._worker_context["None"][0], (_strip_protocol(
+ _bytes_to_str(self._workers[0].target)), 3, True, True))
+ self.assertEqual(self._worker_context[EVALUATOR][0],
+ ("fake_evaluator", 3, True, False))
+
+
+class DistributeCoordinatorTestInpendentWorkerMode(
+ DistributeCoordinatorTestBase):
+
+ def testInGraph(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ threads = self._run_multiple_coordinator_in_threads(
+ self._in_graph_worker_fn,
+ cluster_spec,
+ between_graph=False,
+ mode=INDEPENDENT_WORKER)
+ threads[WORKER][0].join()
+ self.assertEqual(self._result_correct, 1)
+
+ def testBetweenGraph(self):
+ cluster_spec = self._create_cluster_spec(
+ num_workers=NUM_WORKERS, num_ps=NUM_PS)
+ threads = self._run_multiple_coordinator_in_threads(
+ self._between_graph_worker_fn,
+ cluster_spec,
+ between_graph=True,
+ mode=INDEPENDENT_WORKER)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphContext(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ # Dumps the task contexts and std server arguments.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_worker_context,
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ between_graph=True,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # There is only one type of task and three such tasks.
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue(WORKER in self._worker_context)
+ self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._worker_context[WORKER][0],
+ (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
+ self.assertEqual(
+ self._worker_context[WORKER][1],
+ (_bytes_to_str(cluster_spec[WORKER][1]), NUM_WORKERS, False, True))
+ self.assertEqual(
+ self._worker_context[WORKER][2],
+ (_bytes_to_str(cluster_spec[WORKER][2]), NUM_WORKERS, False, True))
+
+ # Make sure each worker runs a std server.
+ self.assertEqual(len(self._std_servers), 1)
+ self.assertTrue(WORKER in self._std_servers)
+ self.assertEqual(len(self._std_servers[WORKER]), 3)
+ self.assertFalse(self._std_servers[WORKER][0].joined)
+ self.assertFalse(self._std_servers[WORKER][1].joined)
+ self.assertFalse(self._std_servers[WORKER][2].joined)
+
+ def testInGraphContext(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ # Dumps the task contexts and std server arguments.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_worker_context,
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ between_graph=False,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # There is only a "None" task in the dumped task context.
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue("None" in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._worker_context["None"][0],
+ (_bytes_to_str(cluster_spec[WORKER][0]), NUM_WORKERS, True, True))
+
+ # Make sure each worker runs a std server.
+ self.assertEqual(len(self._std_servers), 1)
+ self.assertTrue(WORKER in self._std_servers)
+ self.assertEqual(len(self._std_servers[WORKER]), 3)
+ self.assertFalse(self._std_servers[WORKER][0].joined)
+ self.assertTrue(self._std_servers[WORKER][1].joined)
+ self.assertTrue(self._std_servers[WORKER][2].joined)
+
+ def testInGraphContextWithEval(self):
+ # Adds a EVALUATOR job.
+ cluster_spec = self._create_cluster_spec(
+ num_workers=NUM_WORKERS, has_eval=True)
+
+ # Dumps the task contexts and std server arguments.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_worker_context,
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ between_graph=False,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+ threads[EVALUATOR][0].join()
# There are one "None" task and one EVALUATOR task.
self.assertEqual(len(self._worker_context), 2)
@@ -284,10 +537,23 @@ class DistributeCoordinatorTest(test.TestCase):
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(self._worker_context["None"][0],
- (_bytes_to_str(self._workers[0].target), 3, True, True))
+ (_bytes_to_str(cluster_spec[WORKER][0]), 3, True, True))
self.assertEqual(self._worker_context[EVALUATOR][0],
- ("fake_evaluator", 3, False, True))
+ (cluster_spec[EVALUATOR][0], 3, True, False))
+
+ # Make sure each worker runs a std server.
+ self.assertEqual(len(self._std_servers), 2)
+ self.assertTrue(WORKER in self._std_servers)
+ self.assertTrue(EVALUATOR in self._std_servers)
+ self.assertEqual(len(self._std_servers[WORKER]), 3)
+ self.assertEqual(len(self._std_servers[EVALUATOR]), 1)
+ self.assertFalse(self._std_servers[WORKER][0].joined)
+ self.assertTrue(self._std_servers[WORKER][1].joined)
+ self.assertTrue(self._std_servers[WORKER][2].joined)
+ self.assertFalse(self._std_servers[EVALUATOR][0].joined)
if __name__ == "__main__":
- test.main()
+ # TODO(yuefengz): find a smart way to terminite std server threads.
+ with test.mock.patch.object(sys, "exit", os._exit):
+ test.main()