author     Yuefeng Zhou <yuefengz@google.com>               2018-07-31 22:40:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-31 22:45:03 -0700
commit     209efdd2bcc4d7f0d5aeda77fb11f04c37d484e6
tree       4ecd0b71391e91feedf85c04d681b466b61f5d45 /tensorflow/python/distribute
parent     7ca6ee15555db77c09861fc7e84e5181001da07d
Rename coordinator context to worker context.
PiperOrigin-RevId: 206882388
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py       |  46
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  | 110
2 files changed, 79 insertions(+), 77 deletions(-)
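For orientation, here is a minimal sketch of how a `worker_fn` uses the renamed API. It is not part of the commit: the two-worker cluster spec and the body of `worker_fn` are made up, and it assumes `run_distribute_coordinator` accepts a cluster-spec dict the way the tests in the diff below pass one.

    # Hypothetical usage sketch, not from this commit.
    from tensorflow.python.distribute import distribute_coordinator

    def worker_fn():
      # Renamed entry point: formerly get_current_coordinator_context().
      context = distribute_coordinator.get_current_worker_context()
      # Task-specific configuration comes from the context rather than from
      # globals, since worker_fn may run in several threads at once.
      print(context.task_type, context.task_id, context.master_target,
            context.is_chief)

    distribute_coordinator.run_distribute_coordinator(
        worker_fn,
        cluster_spec={"worker": ["localhost:2222", "localhost:2223"]},
        between_graph=True)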
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 04c50dbafc..b5b4b28033 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -34,13 +34,13 @@ class _TaskType(object):
   EVALUATOR = "evaluator"
 
 
-_coordinator_context = threading.local()
+_worker_context = threading.local()
 
 
-def get_current_coordinator_context():
-  """Returns the current coordinator context."""
+def get_current_worker_context():
+  """Returns the current task context."""
   try:
-    return _coordinator_context.current
+    return _worker_context.current
   except AttributeError:
     return None
@@ -86,13 +86,13 @@ def _get_num_workers(cluster_spec):
                              cluster_spec.as_dict().get(_TaskType.CHIEF, []))
 
 
-class _CoordinatorContext(object):
-  """The coordinator context class.
+class _WorkerContext(object):
+  """The worker context class.
 
   This context object provides configuration information for each task. One
-  context manager with a coordinator context object will be created per
-  invocation to the `worker_fn` where `get_current_coordinator_context` can be
-  called to access the coordinator context object.
+  context manager with a worker context object will be created per
+  invocation to the `worker_fn` where `get_current_worker_context` can be called
+  to access the worker context object.
   """
 
   def __init__(self,
@@ -102,7 +102,7 @@ class _CoordinatorContext(object):
                between_graph=False,
                rpc_layer="grpc",
                worker_barrier=None):
-    """Initialize the coordinator context object.
+    """Initialize the worker context object.
 
     Args:
       cluster_spec: a ClusterSpec object. It can be empty or None in the local
@@ -139,15 +139,15 @@ class _CoordinatorContext(object):
     self._is_chief_node = self._is_chief()
 
   def __enter__(self):
-    old_context = get_current_coordinator_context()
+    old_context = get_current_worker_context()
     if old_context:
       raise ValueError(
           "You cannot run distribute coordinator in a `worker_fn`.")
-    _coordinator_context.current = self
+    _worker_context.current = self
 
   def __exit__(self, unused_exception_type, unused_exception_value,
                unused_traceback):
-    _coordinator_context.current = None
+    _worker_context.current = None
 
   def _get_master_target(self):
     """Return the master target for a task."""
@@ -195,7 +195,7 @@ class _CoordinatorContext(object):
     """
     if not self._worker_barrier:
       raise ValueError(
-          "`worker_barrier is not set in the coordinator context.`")
+          "`worker_barrier is not set in the worker context.`")
     self._worker_barrier.wait()
 
   @property
@@ -236,7 +236,7 @@ class _CoordinatorContext(object):
 
 def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
          worker_barrier):
-  with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph,
+  with _WorkerContext(cluster_spec, task_type, task_id, between_graph,
                       rpc_layer, worker_barrier):
     worker_fn()
@@ -266,13 +266,13 @@ def run_distribute_coordinator(worker_fn,
     this `worker_fn`.
 
   The `worker_fn` defines the training logic and is called under a its own
-  coordinator context which can be accessed to via
-  `get_current_coordinator_context`. A coordinator context provides access to
-  configurations for each task, e.g. the task_type, task_id, master target and
-  so on. Since `worker_fn` will be called in a thread and possibly multiple
-  times, caller should be careful when it accesses global data. For example, it
-  is unsafe to define flags in a `worker_fn` or to define different environment
-  variables for different `worker_fn`s.
+  worker context which can be accessed to via `get_current_worker_context`. A
+  worker context provides access to configurations for each task, e.g. the
+  task_type, task_id, master target and so on. Since `worker_fn` will be called
+  in a thread and possibly multiple times, caller should be careful when it
+  accesses global data. For example, it is unsafe to define flags in a
+  `worker_fn` or to define different environment variables for different
+  `worker_fn`s.
 
   The `worker_fn` for the between-graph replication is defined as if there are
   only one worker corresponding to the `worker_fn` and possibly ps jobs. It
@@ -287,7 +287,7 @@ def run_distribute_coordinator(worker_fn,
   high-level APIs, to change a program to use this coordinator, wrap everything
   in a the program after global data definitions such as commandline flag
   definition into the `worker_fn` and get task-specific configurations from
-  the coordinator context.
+  the worker context.
 
   The `cluster_spec` can be either passed by the argument or parsed from the
   "TF_CONFIG" envrionment variable. Example of a TF_CONFIG:
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 82fd823352..d7ffeb56a5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -67,7 +67,7 @@ class DistributeCoordinatorTest(test.TestCase):
   def setUp(self):
     self._result_correct = 0
     self._lock = threading.Lock()
-    self._task_context = {}
+    self._worker_context = {}
 
   @contextlib.contextmanager
   def _test_session(self, target):
@@ -77,7 +77,7 @@ class DistributeCoordinatorTest(test.TestCase):
     yield sess
 
   def _in_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_coordinator_context()
+    context = distribute_coordinator.get_current_worker_context()
     self.assertTrue(context is not None)
     with self._test_session(target=context.master_target) as sess:
       xs = []
@@ -107,7 +107,7 @@ class DistributeCoordinatorTest(test.TestCase):
     self.assertEqual(self._result_correct, 1)
 
   def _between_graph_worker_fn(self):
-    context = distribute_coordinator.get_current_coordinator_context()
+    context = distribute_coordinator.get_current_worker_context()
     self.assertTrue(context is not None)
     with self._test_session(target=context.master_target) as sess:
       with ops.device("/job:ps/task:0"):
@@ -153,113 +153,113 @@ class DistributeCoordinatorTest(test.TestCase):
     # Each finished worker will increment self._result_correct.
     self.assertEqual(self._result_correct, NUM_WORKERS)
 
-  def _dump_task_context(self):
-    """Dumps the propoerties of each coordinator context.
+  def _dump_worker_context(self):
+    """Dumps the propoerties of each worker context.
 
     It dumps the context properties to a dict mapping from task_type to a list
     of tuples of master_target, num_workers, is_chief and distribute_mode,
     where the list is indexed by the task_id.
     """
-    context = distribute_coordinator.get_current_coordinator_context()
+    context = distribute_coordinator.get_current_worker_context()
     self.assertTrue(context is not None)
     task_type = str(context.task_type)
     task_id = context.task_id or 0
     with self._lock:
-      if task_type not in self._task_context:
-        self._task_context[task_type] = []
-      while len(self._task_context[task_type]) <= task_id:
-        self._task_context[task_type].append(None)
-      self._task_context[task_type][task_id] = (context.master_target,
-                                                context.num_workers,
-                                                context.is_chief,
-                                                context.distributed_mode)
+      if task_type not in self._worker_context:
+        self._worker_context[task_type] = []
+      while len(self._worker_context[task_type]) <= task_id:
+        self._worker_context[task_type].append(None)
+      self._worker_context[task_type][task_id] = (context.master_target,
+                                                  context.num_workers,
+                                                  context.is_chief,
+                                                  context.distributed_mode)
 
   def testBetweenGraphContext(self):
-    # Dumps the task contexts to the self._task_context dict.
+    # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_task_context,
+        self._dump_worker_context,
         cluster_spec=self._cluster_spec,
         between_graph=True)
 
     # There is only one type of task and there three such tasks.
-    self.assertEqual(len(self._task_context), 1)
-    self.assertTrue(WORKER in self._task_context)
-    self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+    self.assertEqual(len(self._worker_context), 1)
+    self.assertTrue(WORKER in self._worker_context)
+    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
     self.assertEqual(
-        self._task_context[WORKER][0],
+        self._worker_context[WORKER][0],
         (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
     self.assertEqual(
-        self._task_context[WORKER][1],
+        self._worker_context[WORKER][1],
         (_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
     self.assertEqual(
-        self._task_context[WORKER][2],
+        self._worker_context[WORKER][2],
         (_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
 
   def testInGraphContext(self):
-    # Dumps the task contexts to the self._task_context dict.
+    # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_task_context,
+        self._dump_worker_context,
         cluster_spec=self._cluster_spec,
         between_graph=False)
 
     # There is only a "None" task in the dumped task context.
-    self.assertEqual(len(self._task_context), 1)
-    self.assertTrue("None" in self._task_context)
-    self.assertEqual(len(self._task_context["None"]), 1)
+    self.assertEqual(len(self._worker_context), 1)
+    self.assertTrue("None" in self._worker_context)
+    self.assertEqual(len(self._worker_context["None"]), 1)
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
     self.assertEqual(
-        self._task_context["None"][0],
+        self._worker_context["None"][0],
         (_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
 
   def testLocalContext(self):
-    # Dumps the task contexts to the self._task_context dict.
+    # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_task_context, cluster_spec=None, between_graph=True)
+        self._dump_worker_context, cluster_spec=None, between_graph=True)
 
     # There is only a "None" task.
-    self.assertEqual(len(self._task_context), 1)
-    self.assertTrue("None" in self._task_context)
-    self.assertEqual(len(self._task_context["None"]), 1)
+    self.assertEqual(len(self._worker_context), 1)
+    self.assertTrue("None" in self._worker_context)
+    self.assertEqual(len(self._worker_context["None"]), 1)
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
-    self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
+    self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
 
   def testBetweenGraphContextWithChief(self):
     # Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
     cluster_spec = copy.deepcopy(self._cluster_spec)
     cluster_spec[CHIEF] = ["fake_chief"]
 
-    # Dumps the task contexts to the self._task_context dict.
+    # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_task_context,
+        self._dump_worker_context,
         cluster_spec=cluster_spec,
         between_graph=True,
         rpc_layer="grpc")
 
     # There are one CHIEF and three workers.
-    self.assertEqual(len(self._task_context), 2)
-    self.assertTrue(CHIEF in self._task_context)
-    self.assertTrue(WORKER in self._task_context)
-    self.assertEqual(len(self._task_context[CHIEF]), 1)
-    self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+    self.assertEqual(len(self._worker_context), 2)
+    self.assertTrue(CHIEF in self._worker_context)
+    self.assertTrue(WORKER in self._worker_context)
+    self.assertEqual(len(self._worker_context[CHIEF]), 1)
+    self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
-    self.assertEqual(self._task_context[CHIEF][0],
+    self.assertEqual(self._worker_context[CHIEF][0],
                      ("grpc://fake_chief", 4, True, True))
-    self.assertEqual(self._task_context[WORKER][0],
+    self.assertEqual(self._worker_context[WORKER][0],
                      ("grpc://" + _bytes_to_str(self._workers[0].target),
                       NUM_WORKERS + 1, False, True))
-    self.assertEqual(self._task_context[WORKER][1],
+    self.assertEqual(self._worker_context[WORKER][1],
                      ("grpc://" + _bytes_to_str(self._workers[1].target),
                       NUM_WORKERS + 1, False, True))
-    self.assertEqual(self._task_context[WORKER][2],
+    self.assertEqual(self._worker_context[WORKER][2],
                      ("grpc://" + _bytes_to_str(self._workers[2].target),
                       NUM_WORKERS + 1, False, True))
@@ -268,22 +268,24 @@ class DistributeCoordinatorTest(test.TestCase):
     cluster_spec = copy.deepcopy(self._cluster_spec)
     cluster_spec[EVALUATOR] = ["fake_evaluator"]
 
-    # Dumps the task contexts to the self._task_context dict.
+    # Dumps the task contexts to the self._worker_context dict.
     distribute_coordinator.run_distribute_coordinator(
-        self._dump_task_context, cluster_spec=cluster_spec, between_graph=False)
+        self._dump_worker_context,
+        cluster_spec=cluster_spec,
+        between_graph=False)
 
     # There are one "None" task and one EVALUATOR task.
-    self.assertEqual(len(self._task_context), 2)
-    self.assertTrue("None" in self._task_context)
-    self.assertTrue(EVALUATOR in self._task_context)
-    self.assertEqual(len(self._task_context["None"]), 1)
-    self.assertEqual(len(self._task_context[EVALUATOR]), 1)
+    self.assertEqual(len(self._worker_context), 2)
+    self.assertTrue("None" in self._worker_context)
+    self.assertTrue(EVALUATOR in self._worker_context)
+    self.assertEqual(len(self._worker_context["None"]), 1)
+    self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
 
     # Check whether each task has the right master_target, num_workers, is_chief
     # and distributed_mode.
-    self.assertEqual(self._task_context["None"][0],
+    self.assertEqual(self._worker_context["None"][0],
                      (_bytes_to_str(self._workers[0].target), 3, True, True))
-    self.assertEqual(self._task_context[EVALUATOR][0],
+    self.assertEqual(self._worker_context[EVALUATOR][0],
                      ("fake_evaluator", 3, False, True))
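The `run_distribute_coordinator` docstring in the first file ends its hunk at "Example of a TF_CONFIG:", so the example itself falls outside this diff. For orientation only, a TF_CONFIG of the general shape the coordinator parses might look like the sketch below; the host addresses are placeholders, not taken from the commit.

    # Hypothetical TF_CONFIG; hosts are made-up placeholders.
    import json
    import os

    cluster = {"chief": ["host0:2222"],
               "ps": ["host1:2222", "host2:2222"],
               "worker": ["host3:2222", "host4:2222", "host5:2222"]}
    os.environ["TF_CONFIG"] = json.dumps({"cluster": cluster})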