author     Yuefeng Zhou <yuefengz@google.com>  2018-07-27 14:58:10 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-27 15:02:45 -0700
commit     2beb3a9d8b9df294e7635cc23d195a76fd78de79 (patch)
tree       c53e0ba8f738093c5e97e5bc2ee6a5b064afcf2b /tensorflow/python/distribute
parent     470b43af3153942e4ef838610aa07e93f904fd5f (diff)
Add distribute_coordinator: a unified and split client for distributed training.
PiperOrigin-RevId: 206378953
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/BUILD                           |  40
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py      | 361
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py | 291
3 files changed, 692 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
new file mode 100644
index 0000000000..a29043d8b8
--- /dev/null
+++ b/tensorflow/python/distribute/BUILD
@@ -0,0 +1,40 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+py_library(
+ name = "distribute_coordinator",
+ srcs = [
+ "distribute_coordinator.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:training",
+ ],
+)
+
+py_test(
+ name = "distribute_coordinator_test",
+ size = "small",
+ srcs = ["distribute_coordinator_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distribute_coordinator",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distributed_framework_test_lib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
new file mode 100644
index 0000000000..04c50dbafc
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -0,0 +1,361 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A unified and split coordinator for distributed TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import json
+import os
+import threading
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import server_lib
+
+
+class _TaskType(object):
+ PS = "ps"
+ WORKER = "worker"
+ CHIEF = "chief"
+ EVALUATOR = "evaluator"
+
+
+_coordinator_context = threading.local()
+
+
+def get_current_coordinator_context():
+ """Returns the current coordinator context."""
+ try:
+ return _coordinator_context.current
+ except AttributeError:
+ return None
+
+
+class _Barrier(object):
+ """A reusable barrier class for worker synchronization."""
+
+ def __init__(self, num_participants):
+ """Initializes the barrier object.
+
+ Args:
+      num_participants: an integer which is the expected number of calls to
+        `wait` that pass through this barrier.
+ """
+ self._num_participants = num_participants
+ self._counter = 0
+ self._flag = False
+ self._local_sense = threading.local()
+ self._lock = threading.Lock()
+ self._condition = threading.Condition()
+
+ def wait(self):
+ """Waits until all other callers reach the same wait call."""
+ if not hasattr(self._local_sense, "value"):
+ self._local_sense.value = False
+ self._local_sense.value = not self._flag
+ with self._lock:
+ self._counter += 1
+ if self._counter == self._num_participants:
+ self._counter = 0
+ self._flag = self._local_sense.value
+ with self._condition:
+ while self._flag != self._local_sense.value:
+ self._condition.wait()
+ self._condition.notify_all()
+
+
+def _get_num_workers(cluster_spec):
+ """Gets number of workers including chief."""
+ if not cluster_spec:
+ return 0
+ return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
+ cluster_spec.as_dict().get(_TaskType.CHIEF, []))
+
+
+class _CoordinatorContext(object):
+ """The coordinator context class.
+
+ This context object provides configuration information for each task. One
+ context manager with a coordinator context object will be created per
+ invocation to the `worker_fn` where `get_current_coordinator_context` can be
+ called to access the coordinator context object.
+ """
+
+ def __init__(self,
+ cluster_spec,
+ task_type,
+ task_id,
+ between_graph=False,
+ rpc_layer="grpc",
+ worker_barrier=None):
+ """Initialize the coordinator context object.
+
+ Args:
+ cluster_spec: a ClusterSpec object. It can be empty or None in the local
+ training case.
+ task_type: a string indicating the role of the corresponding task, such as
+ "worker" or "ps". It can be None if it is local training or
+ `between_graph` is False.
+      task_id: an integer indicating the id of the corresponding task. It can
+        be None if it is local training or `between_graph` is False.
+ between_graph: whether it is between-graph replication or not.
+ rpc_layer: optional string specifying the RPC protocol for communication
+ with worker masters. If None or empty, hosts in the `cluster_spec` will
+ be used directly.
+ worker_barrier: optional, the barrier object for worker synchronization.
+
+ Raises:
+      ValueError: if `task_type` or `task_id` is None or empty and it is
+        distributed between-graph replicated training.
+ """
+ if cluster_spec and between_graph:
+ if not task_type or task_id is None:
+ raise ValueError("`task_type` and `task_id` must be set in the "
+ "distributed between-graph replicated training.")
+ if task_type not in cluster_spec.jobs:
+ raise ValueError("`task_type` %r not found in the `cluster_spec` %r" %
+ (task_type, cluster_spec))
+    self._cluster_spec = cluster_spec
+    self._between_graph = between_graph
+ self._task_type = task_type
+ self._task_id = task_id
+ self._worker_barrier = worker_barrier
+ self._rpc_layer = rpc_layer
+ self._master_target = self._get_master_target()
+ self._num_workers = _get_num_workers(cluster_spec)
+ self._is_chief_node = self._is_chief()
+
+ def __enter__(self):
+ old_context = get_current_coordinator_context()
+ if old_context:
+ raise ValueError(
+ "You cannot run distribute coordinator in a `worker_fn`.")
+ _coordinator_context.current = self
+
+ def __exit__(self, unused_exception_type, unused_exception_value,
+ unused_traceback):
+ _coordinator_context.current = None
+
+ def _get_master_target(self):
+ """Return the master target for a task."""
+ # If cluster_spec is None or empty, we use local master.
+ if not self._cluster_spec:
+ return "local"
+
+ # If task_type is None, then it is in-graph replicated training. In this
+ # case we use the chief or first worker's master target.
+ if not self._task_type:
+ if _TaskType.CHIEF in self._cluster_spec.jobs:
+        assert not self._between_graph
+ task_type = _TaskType.CHIEF
+ task_id = 0
+ else:
+ assert _TaskType.WORKER in self._cluster_spec.jobs
+ task_type = _TaskType.WORKER
+ task_id = 0
+ else:
+ task_type = self._task_type
+ task_id = self._task_id
+
+ prefix = ""
+ if self._rpc_layer:
+ prefix = self._rpc_layer + "://"
+ return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]
+
+ def _is_chief(self):
+ """Return whether the task is the chief worker."""
+ if (not self._cluster_spec or self._task_type in [_TaskType.CHIEF, None]):
+ return True
+
+ # If not local and chief not in the cluster_spec, use the first worker as
+ # chief.
+ if (_TaskType.CHIEF not in self._cluster_spec.jobs and
+ self._task_type == _TaskType.WORKER and self._task_id == 0):
+ return True
+ return False
+
+ def wait_for_other_workers(self):
+ """Waits for other workers to reach the same call to this method.
+
+ Raises:
+ ValueError: if `worker_barrier` is not passed to the __init__ method.
+ """
+ if not self._worker_barrier:
+      raise ValueError(
+          "`worker_barrier` is not set in the coordinator context.")
+ self._worker_barrier.wait()
+
+ @property
+ def distributed_mode(self):
+ """Whether it is distributed training or not."""
+ return bool(self._cluster_spec)
+
+ @property
+ def cluster_spec(self):
+ """Returns a copy of the cluster_spec object."""
+ return copy.deepcopy(self._cluster_spec)
+
+ @property
+  def task_type(self):
+    """Returns the role of the corresponding task."""
+ return self._task_type
+
+ @property
+  def task_id(self):
+    """Returns the id or index of the corresponding task."""
+ return self._task_id
+
+ @property
+ def master_target(self):
+ """Returns the session master for the corresponding task to connect to."""
+ return self._master_target
+
+ @property
+ def is_chief(self):
+ """Returns whether the task is a chief node."""
+ return self._is_chief_node
+
+ @property
+ def num_workers(self):
+ """Returns number of workers in the cluster, including chief."""
+ return self._num_workers
+
+
+def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
+ worker_barrier):
+ with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph,
+ rpc_layer, worker_barrier):
+ worker_fn()
+
+
+def run_distribute_coordinator(worker_fn,
+ cluster_spec=None,
+ between_graph=False,
+ rpc_layer=None):
+ """Run the coordinator for distributed TensorFlow.
+
+ This function runs a unified and split coordinator for distributed TensorFlow.
+ Given a `cluster_spec` specifying server addresses and their roles in a
+ cluster, this coordinator will figure out how to set them up, give the
+ underlying function the right targets for master sessions and coordinate their
+ training.
+
+  In addition to being the distribute coordinator, this is also the source of
+ configurations for each job in the distributed training. As there are multiple
+ ways to configure a distributed TensorFlow cluster, its context object
+ provides these configurations so that users or higher-level APIs don't have to
+ figure out the configuration for each job by themselves.
+
+  In between-graph replicated training, this coordinator will create multiple
+  threads, each of which calls the `worker_fn`, which is supposed to create its
+  own graph and connect to one worker master given by its coordinator context.
+  In in-graph replicated training, only one thread calls the `worker_fn`.
+
+  The `worker_fn` defines the training logic and is called under its own
+  coordinator context, which can be accessed via
+  `get_current_coordinator_context`. A coordinator context provides access to
+  configurations for each task, e.g. the task_type, task_id, master target and
+  so on. Since `worker_fn` will be called in a thread and possibly multiple
+  times, callers should be careful when accessing global data. For example, it
+  is unsafe to define flags in a `worker_fn` or to define different environment
+  variables for different `worker_fn`s.
+
+  The `worker_fn` for between-graph replication is defined as if there is only
+  one worker corresponding to the `worker_fn` and possibly ps jobs. It
+ assigns variables to parameter servers and all other operations to that
+ worker. In the in-graph replication case, the `worker_fn` has to define
+ operations for all worker jobs. Using a distribution strategy can simplify the
+ `worker_fn` by not having to worry about the replication and device assignment
+ of variables and operations.
+
+  This method is intended to be invoked by high-level APIs so that users don't
+  have to explicitly call it to run this coordinator. For those who don't use
+  high-level APIs, to change a program to use this coordinator, wrap everything
+  in the program after global data definitions, such as command-line flag
+  definitions, into the `worker_fn` and get task-specific configurations from
+  the coordinator context.
+
+  The `cluster_spec` can either be passed as an argument or parsed from the
+  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
+ ```
+ cluster = {'chief': ['host0:2222'],
+ 'ps': ['host1:2222', 'host2:2222'],
+ 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
+ os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
+ ```
+
+ If `cluster_spec` is not given in any format, it becomes local training and
+ this coordinator will connect to a local session.
+
+  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
+  will be created with its `task_type` set to "evaluator". If "evaluator" is not
+  set in the cluster_spec, it is entirely up to the `worker_fn` to decide how to
+  do evaluation.
+
+ Args:
+    worker_fn: the function to be called and given access to a coordinator
+      context object.
+ cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
+ in a cluster. If not set or empty, fall back to local training.
+ between_graph: a boolean. It is only useful when `cluster_spec` is set and
+ not empty. If true, it will use between-graph replicated training;
+ otherwise it will use in-graph replicated training.
+ rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
+
+ Raises:
+ ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
+ a ClusterSpec.
+ """
+ if not cluster_spec:
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ cluster_spec = tf_config.get("cluster", {})
+
+ if cluster_spec:
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ cluster_spec = server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+      raise ValueError(
+          "`cluster_spec` should be a dict or a `tf.train.ClusterSpec` or a "
+          "`tf.train.ClusterDef` object")
+ # TODO(yuefengz): validate cluster_spec.
+
+ threads = []
+ if cluster_spec and _TaskType.EVALUATOR in cluster_spec.jobs:
+ t = threading.Thread(
+ target=_run,
+ args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0, between_graph,
+ rpc_layer, None))
+ t.start()
+ threads.append(t)
+
+ if cluster_spec and between_graph:
+ worker_barrier = _Barrier(_get_num_workers(cluster_spec))
+ for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
+ for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
+ t = threading.Thread(
+ target=_run,
+ args=(worker_fn, cluster_spec, task_type, task_id, between_graph,
+ rpc_layer, worker_barrier))
+ t.start()
+ threads.append(t)
+ else:
+ # Local or in-graph replicated training.
+ _run(worker_fn, cluster_spec, None, None, between_graph, rpc_layer, None)
+
+  # TODO(yuefengz): wrap threads into a thread coordinator?
+ for t in threads:
+ t.join()
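
Editor's note: to make the new API above concrete, here is a minimal usage sketch (not part of the commit) of a between-graph `worker_fn` driven by `run_distribute_coordinator`. The cluster addresses, variable names and training op are made up for illustration; only the coordinator-context calls (`get_current_coordinator_context`, `master_target`, `is_chief`, `wait_for_other_workers`) come from the module above.
```
# Hypothetical sketch; hosts and ops are placeholders, TF 1.x style API.
import tensorflow as tf
from tensorflow.python.distribute import distribute_coordinator as dc


def worker_fn():
  # Each between-graph thread gets its own coordinator context.
  context = dc.get_current_coordinator_context()
  with tf.Graph().as_default():
    with tf.device("/job:ps/task:0"):
      step = tf.get_variable("step", initializer=0.0, use_resource=True)
    train_op = step.assign_add(1.0)
    with tf.Session(target=context.master_target) as sess:
      if context.is_chief:
        sess.run(tf.global_variables_initializer())
      # Make sure every worker sees initialized variables before training.
      context.wait_for_other_workers()
      sess.run(train_op)


cluster = {"chief": ["host0:2222"],
           "worker": ["host1:2222", "host2:2222"],
           "ps": ["host3:2222"]}
dc.run_distribute_coordinator(
    worker_fn, cluster_spec=cluster, between_graph=True, rpc_layer="grpc")
```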
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
new file mode 100644
index 0000000000..82fd823352
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -0,0 +1,291 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import copy
+import threading
+import six
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+CHIEF = distribute_coordinator._TaskType.CHIEF
+WORKER = distribute_coordinator._TaskType.WORKER
+PS = distribute_coordinator._TaskType.PS
+EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
+
+NUM_WORKERS = 3
+NUM_PS = 2
+
+
+def _bytes_to_str(maybe_bytes):
+ if isinstance(maybe_bytes, six.string_types):
+ return maybe_bytes
+ else:
+ return str(maybe_bytes, "utf-8")
+
+
+class DistributeCoordinatorTest(test.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # We have to create a global in-process cluster because once an in-process
+ # tensorflow server is created, there is no way to terminate it. Please see
+ # multi_worker_test_base.py for more details.
+ cls._workers, cls._ps = test_util.create_local_cluster(
+ NUM_WORKERS, num_ps=NUM_PS)
+ cls._cluster_spec = {
+ WORKER: [_bytes_to_str(w.target) for w in cls._workers],
+ PS: [_bytes_to_str(ps.target) for ps in cls._ps]
+ }
+
+ def setUp(self):
+ self._result_correct = 0
+ self._lock = threading.Lock()
+ self._task_context = {}
+
+ @contextlib.contextmanager
+ def _test_session(self, target):
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ config.graph_options.optimizer_options.opt_level = -1
+ with session.Session(graph=None, config=config, target=target) as sess:
+ yield sess
+
+ def _in_graph_worker_fn(self):
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ with self._test_session(target=context.master_target) as sess:
+ xs = []
+ expected = 0.0
+ for i in range(context.num_workers):
+ with ops.device("/job:worker/task:%d" % i):
+ x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
+ x_add = x.assign_add(float(i))
+ xs.append(x_add)
+ expected += i + 10.0
+
+ with ops.device("/job:worker/task:0"):
+ result = math_ops.add_n(xs)
+
+ variables.global_variables_initializer().run()
+ result_value = sess.run(result)
+ self.assertEqual(result_value, expected)
+ if result_value == expected:
+ self._result_correct += 1
+
+ def testInGraph(self):
+ """Test it runs in-graph replicated training correctly."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._in_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=False)
+ self.assertEqual(self._result_correct, 1)
+
+ def _between_graph_worker_fn(self):
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ with self._test_session(target=context.master_target) as sess:
+ with ops.device("/job:ps/task:0"):
+ # TODO(yuefengz): investigate why not using resource variable will make
+ # the test flaky.
+ x = variable_scope.get_variable(
+ "x", initializer=10.0, use_resource=True)
+ with ops.device("/job:ps/task:1"):
+ y = variable_scope.get_variable(
+ "y", initializer=20.0, use_resource=True)
+
+ x_add = x.assign_add(2.0)
+ y_sub = y.assign_sub(2.0)
+ train_op = control_flow_ops.group([x_add, y_sub])
+
+ if context.is_chief:
+ variables.global_variables_initializer().run()
+
+      # Synchronize workers after initialization.
+ context.wait_for_other_workers()
+
+ sess.run(train_op)
+
+ # Synchronize workers after one step to make sure they all have finished
+ # training.
+ context.wait_for_other_workers()
+
+ x_val, y_val = sess.run([x, y])
+
+ self.assertEqual(x_val, 16.0)
+ self.assertEqual(y_val, 14.0)
+ if x_val == 16.0 and y_val == 14.0:
+ with self._lock:
+ self._result_correct += 1
+
+ def testBetweenGraph(self):
+ """Test it runs between-graph replicated training correctly."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_worker_fn,
+ cluster_spec=self._cluster_spec,
+ between_graph=True)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+  def _dump_task_context(self):
+    """Dumps the properties of each coordinator context.
+
+    It dumps the context properties to a dict mapping from task_type to a list
+    of tuples of (master_target, num_workers, is_chief, distributed_mode),
+    where the list is indexed by the task_id.
+ """
+ context = distribute_coordinator.get_current_coordinator_context()
+ self.assertTrue(context is not None)
+ task_type = str(context.task_type)
+ task_id = context.task_id or 0
+ with self._lock:
+ if task_type not in self._task_context:
+ self._task_context[task_type] = []
+ while len(self._task_context[task_type]) <= task_id:
+ self._task_context[task_type].append(None)
+ self._task_context[task_type][task_id] = (context.master_target,
+ context.num_workers,
+ context.is_chief,
+ context.distributed_mode)
+
+ def testBetweenGraphContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=self._cluster_spec,
+ between_graph=True)
+
+    # There is only one type of task and there are three such tasks.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue(WORKER in self._task_context)
+ self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._task_context[WORKER][0],
+ (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
+ self.assertEqual(
+ self._task_context[WORKER][1],
+ (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
+ self.assertEqual(
+ self._task_context[WORKER][2],
+ (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
+
+ def testInGraphContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=self._cluster_spec,
+ between_graph=False)
+
+ # There is only a "None" task in the dumped task context.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue("None" in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(
+ self._task_context["None"][0],
+ (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
+
+ def testLocalContext(self):
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context, cluster_spec=None, between_graph=True)
+
+ # There is only a "None" task.
+ self.assertEqual(len(self._task_context), 1)
+ self.assertTrue("None" in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
+
+ def testBetweenGraphContextWithChief(self):
+ # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
+ cluster_spec = copy.deepcopy(self._cluster_spec)
+ cluster_spec[CHIEF] = ["fake_chief"]
+
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context,
+ cluster_spec=cluster_spec,
+ between_graph=True,
+ rpc_layer="grpc")
+
+    # There is one CHIEF task and there are three WORKER tasks.
+ self.assertEqual(len(self._task_context), 2)
+ self.assertTrue(CHIEF in self._task_context)
+ self.assertTrue(WORKER in self._task_context)
+ self.assertEqual(len(self._task_context[CHIEF]), 1)
+ self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context[CHIEF][0],
+ ("grpc://fake_chief", 4, True, True))
+ self.assertEqual(self._task_context[WORKER][0],
+ ("grpc://" + _bytes_to_str(self._workers[0].target),
+ NUM_WORKERS + 1, False, True))
+ self.assertEqual(self._task_context[WORKER][1],
+ ("grpc://" + _bytes_to_str(self._workers[1].target),
+ NUM_WORKERS + 1, False, True))
+ self.assertEqual(self._task_context[WORKER][2],
+ ("grpc://" + _bytes_to_str(self._workers[2].target),
+ NUM_WORKERS + 1, False, True))
+
+ def testInGraphContextWithEval(self):
+    # Adds an EVALUATOR job.
+ cluster_spec = copy.deepcopy(self._cluster_spec)
+ cluster_spec[EVALUATOR] = ["fake_evaluator"]
+
+ # Dumps the task contexts to the self._task_context dict.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_task_context, cluster_spec=cluster_spec, between_graph=False)
+
+    # There is one "None" task and one EVALUATOR task.
+ self.assertEqual(len(self._task_context), 2)
+ self.assertTrue("None" in self._task_context)
+ self.assertTrue(EVALUATOR in self._task_context)
+ self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._task_context[EVALUATOR]), 1)
+
+ # Check whether each task has the right master_target, num_workers, is_chief
+ # and distributed_mode.
+ self.assertEqual(self._task_context["None"][0],
+ (_bytes_to_str(self._workers[0].target), 3, True, True))
+ self.assertEqual(self._task_context[EVALUATOR][0],
+ ("fake_evaluator", 3, False, True))
+
+
+if __name__ == "__main__":
+ test.main()
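
Editor's note: the tests above always pass `cluster_spec` explicitly. The module's docstring also allows the cluster to come from the `TF_CONFIG` environment variable, and falls back to a local session when no cluster is given. Below is a hedged sketch (not part of the commit) of those two paths; the hosts are placeholders and are never connected to, since the `worker_fn` only prints the per-task configuration.
```
# Hypothetical sketch of the TF_CONFIG and local-fallback paths.
import json
import os

from tensorflow.python.distribute import distribute_coordinator as dc


def worker_fn():
  context = dc.get_current_coordinator_context()
  print(context.task_type, context.task_id, context.master_target,
        context.num_workers, context.is_chief, context.distributed_mode)


# 1) Cluster parsed from TF_CONFIG; only the "cluster" key is read here.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"chief": ["host0:2222"],
                "worker": ["host1:2222", "host2:2222"],
                "ps": ["host3:2222"]}
})
dc.run_distribute_coordinator(worker_fn, between_graph=True, rpc_layer="grpc")

# 2) Local fallback: no cluster_spec and no TF_CONFIG cluster, so a single
#    call to worker_fn sees master_target == "local", num_workers == 0,
#    is_chief == True and distributed_mode == False (as in testLocalContext).
del os.environ["TF_CONFIG"]
dc.run_distribute_coordinator(worker_fn)
```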