author     Yuefeng Zhou <yuefengz@google.com>    2018-07-31 22:40:31 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-07-31 22:45:03 -0700
commit     209efdd2bcc4d7f0d5aeda77fb11f04c37d484e6 (patch)
tree       4ecd0b71391e91feedf85c04d681b466b61f5d45 /tensorflow/python/distribute
parent     7ca6ee15555db77c09861fc7e84e5181001da07d (diff)
Rename coordinator context to worker context.
PiperOrigin-RevId: 206882388
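
For context, a minimal sketch of the renamed API in use, mirroring the local-mode test (testLocalContext) in this diff. The import path follows the file location in the diff, the keyword arguments are the ones the tests pass, and the printed properties are the ones the tests inspect:

from tensorflow.python.distribute import distribute_coordinator


def worker_fn():
  # After this change, the per-task context is obtained with
  # get_current_worker_context() rather than get_current_coordinator_context().
  context = distribute_coordinator.get_current_worker_context()
  assert context is not None
  # Properties exercised by the tests below: master_target, num_workers,
  # is_chief and distributed_mode.
  print(context.task_type, context.task_id, context.master_target,
        context.num_workers, context.is_chief, context.distributed_mode)


# cluster_spec=None runs in local mode, where the context reports a "local"
# master target with is_chief=True and distributed_mode=False.
distribute_coordinator.run_distribute_coordinator(
    worker_fn, cluster_spec=None, between_graph=True)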
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py        46
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  110
2 files changed, 79 insertions, 77 deletions
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 04c50dbafc..b5b4b28033 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -34,13 +34,13 @@ class _TaskType(object):
EVALUATOR = "evaluator"
-_coordinator_context = threading.local()
+_worker_context = threading.local()
-def get_current_coordinator_context():
- """Returns the current coordinator context."""
+def get_current_worker_context():
+ """Returns the current task context."""
try:
- return _coordinator_context.current
+ return _worker_context.current
except AttributeError:
return None
@@ -86,13 +86,13 @@ def _get_num_workers(cluster_spec):
cluster_spec.as_dict().get(_TaskType.CHIEF, []))
-class _CoordinatorContext(object):
- """The coordinator context class.
+class _WorkerContext(object):
+ """The worker context class.
This context object provides configuration information for each task. One
- context manager with a coordinator context object will be created per
- invocation to the `worker_fn` where `get_current_coordinator_context` can be
- called to access the coordinator context object.
+ context manager with a worker context object will be created per
+ invocation to the `worker_fn` where `get_current_worker_context` can be called
+ to access the worker context object.
"""
def __init__(self,
@@ -102,7 +102,7 @@ class _CoordinatorContext(object):
between_graph=False,
rpc_layer="grpc",
worker_barrier=None):
- """Initialize the coordinator context object.
+ """Initialize the worker context object.
Args:
cluster_spec: a ClusterSpec object. It can be empty or None in the local
@@ -139,15 +139,15 @@ class _CoordinatorContext(object):
self._is_chief_node = self._is_chief()
def __enter__(self):
- old_context = get_current_coordinator_context()
+ old_context = get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.")
- _coordinator_context.current = self
+ _worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
- _coordinator_context.current = None
+ _worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
@@ -195,7 +195,7 @@ class _CoordinatorContext(object):
"""
if not self._worker_barrier:
raise ValueError(
- "`worker_barrier is not set in the coordinator context.`")
+ "`worker_barrier is not set in the worker context.`")
self._worker_barrier.wait()
@property
@@ -236,7 +236,7 @@ class _CoordinatorContext(object):
def _run(worker_fn, cluster_spec, task_type, task_id, between_graph, rpc_layer,
worker_barrier):
- with _CoordinatorContext(cluster_spec, task_type, task_id, between_graph,
+ with _WorkerContext(cluster_spec, task_type, task_id, between_graph,
rpc_layer, worker_barrier):
worker_fn()
@@ -266,13 +266,13 @@ def run_distribute_coordinator(worker_fn,
this `worker_fn`.
The `worker_fn` defines the training logic and is called under a its own
- coordinator context which can be accessed to via
- `get_current_coordinator_context`. A coordinator context provides access to
- configurations for each task, e.g. the task_type, task_id, master target and
- so on. Since `worker_fn` will be called in a thread and possibly multiple
- times, caller should be careful when it accesses global data. For example, it
- is unsafe to define flags in a `worker_fn` or to define different environment
- variables for different `worker_fn`s.
+ worker context which can be accessed to via `get_current_worker_context`. A
+ worker context provides access to configurations for each task, e.g. the
+ task_type, task_id, master target and so on. Since `worker_fn` will be called
+ in a thread and possibly multiple times, caller should be careful when it
+ accesses global data. For example, it is unsafe to define flags in a
+ `worker_fn` or to define different environment variables for different
+ `worker_fn`s.
The `worker_fn` for the between-graph replication is defined as if there are
only one worker corresponding to the `worker_fn` and possibly ps jobs. It
@@ -287,7 +287,7 @@ def run_distribute_coordinator(worker_fn,
high-level APIs, to change a program to use this coordinator, wrap everything
in a the program after global data definitions such as commandline flag
definition into the `worker_fn` and get task-specific configurations from
- the coordinator context.
+ the worker context.
The `cluster_spec` can be either passed by the argument or parsed from the
"TF_CONFIG" envrionment variable. Example of a TF_CONFIG:
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 82fd823352..d7ffeb56a5 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -67,7 +67,7 @@ class DistributeCoordinatorTest(test.TestCase):
def setUp(self):
self._result_correct = 0
self._lock = threading.Lock()
- self._task_context = {}
+ self._worker_context = {}
@contextlib.contextmanager
def _test_session(self, target):
@@ -77,7 +77,7 @@ class DistributeCoordinatorTest(test.TestCase):
yield sess
def _in_graph_worker_fn(self):
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@@ -107,7 +107,7 @@ class DistributeCoordinatorTest(test.TestCase):
self.assertEqual(self._result_correct, 1)
def _between_graph_worker_fn(self):
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@@ -153,113 +153,113 @@ class DistributeCoordinatorTest(test.TestCase):
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
- def _dump_task_context(self):
- """Dumps the propoerties of each coordinator context.
+ def _dump_worker_context(self):
+ """Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distribute_mode, where
the list is indexed by the task_id.
"""
- context = distribute_coordinator.get_current_coordinator_context()
+ context = distribute_coordinator.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
with self._lock:
- if task_type not in self._task_context:
- self._task_context[task_type] = []
- while len(self._task_context[task_type]) <= task_id:
- self._task_context[task_type].append(None)
- self._task_context[task_type][task_id] = (context.master_target,
- context.num_workers,
- context.is_chief,
- context.distributed_mode)
+ if task_type not in self._worker_context:
+ self._worker_context[task_type] = []
+ while len(self._worker_context[task_type]) <= task_id:
+ self._worker_context[task_type].append(None)
+ self._worker_context[task_type][task_id] = (context.master_target,
+ context.num_workers,
+ context.is_chief,
+ context.distributed_mode)
def testBetweenGraphContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=True)
# There is only one type of task and there three such tasks.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue(WORKER in self._task_context)
- self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue(WORKER in self._worker_context)
+ self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
- self._task_context[WORKER][0],
+ self._worker_context[WORKER][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
self.assertEqual(
- self._task_context[WORKER][1],
+ self._worker_context[WORKER][1],
(_bytes_to_str(self._workers[1].target), NUM_WORKERS, False, True))
self.assertEqual(
- self._task_context[WORKER][2],
+ self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
def testInGraphContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=self._cluster_spec,
between_graph=False)
# There is only a "None" task in the dumped task context.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue("None" in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue("None" in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
self.assertEqual(
- self._task_context["None"][0],
+ self._worker_context["None"][0],
(_bytes_to_str(self._workers[0].target), NUM_WORKERS, True, True))
def testLocalContext(self):
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context, cluster_spec=None, between_graph=True)
+ self._dump_worker_context, cluster_spec=None, between_graph=True)
# There is only a "None" task.
- self.assertEqual(len(self._task_context), 1)
- self.assertTrue("None" in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
+ self.assertEqual(len(self._worker_context), 1)
+ self.assertTrue("None" in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context["None"][0], ("local", 0, True, False))
+ self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[CHIEF] = ["fake_chief"]
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context,
+ self._dump_worker_context,
cluster_spec=cluster_spec,
between_graph=True,
rpc_layer="grpc")
# There are one CHIEF and three workers.
- self.assertEqual(len(self._task_context), 2)
- self.assertTrue(CHIEF in self._task_context)
- self.assertTrue(WORKER in self._task_context)
- self.assertEqual(len(self._task_context[CHIEF]), 1)
- self.assertEqual(len(self._task_context[WORKER]), NUM_WORKERS)
+ self.assertEqual(len(self._worker_context), 2)
+ self.assertTrue(CHIEF in self._worker_context)
+ self.assertTrue(WORKER in self._worker_context)
+ self.assertEqual(len(self._worker_context[CHIEF]), 1)
+ self.assertEqual(len(self._worker_context[WORKER]), NUM_WORKERS)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context[CHIEF][0],
+ self.assertEqual(self._worker_context[CHIEF][0],
("grpc://fake_chief", 4, True, True))
- self.assertEqual(self._task_context[WORKER][0],
+ self.assertEqual(self._worker_context[WORKER][0],
("grpc://" + _bytes_to_str(self._workers[0].target),
NUM_WORKERS + 1, False, True))
- self.assertEqual(self._task_context[WORKER][1],
+ self.assertEqual(self._worker_context[WORKER][1],
("grpc://" + _bytes_to_str(self._workers[1].target),
NUM_WORKERS + 1, False, True))
- self.assertEqual(self._task_context[WORKER][2],
+ self.assertEqual(self._worker_context[WORKER][2],
("grpc://" + _bytes_to_str(self._workers[2].target),
NUM_WORKERS + 1, False, True))
@@ -268,22 +268,24 @@ class DistributeCoordinatorTest(test.TestCase):
cluster_spec = copy.deepcopy(self._cluster_spec)
cluster_spec[EVALUATOR] = ["fake_evaluator"]
- # Dumps the task contexts to the self._task_context dict.
+ # Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_task_context, cluster_spec=cluster_spec, between_graph=False)
+ self._dump_worker_context,
+ cluster_spec=cluster_spec,
+ between_graph=False)
# There are one "None" task and one EVALUATOR task.
- self.assertEqual(len(self._task_context), 2)
- self.assertTrue("None" in self._task_context)
- self.assertTrue(EVALUATOR in self._task_context)
- self.assertEqual(len(self._task_context["None"]), 1)
- self.assertEqual(len(self._task_context[EVALUATOR]), 1)
+ self.assertEqual(len(self._worker_context), 2)
+ self.assertTrue("None" in self._worker_context)
+ self.assertTrue(EVALUATOR in self._worker_context)
+ self.assertEqual(len(self._worker_context["None"]), 1)
+ self.assertEqual(len(self._worker_context[EVALUATOR]), 1)
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._task_context["None"][0],
+ self.assertEqual(self._worker_context["None"][0],
(_bytes_to_str(self._workers[0].target), 3, True, True))
- self.assertEqual(self._task_context[EVALUATOR][0],
+ self.assertEqual(self._worker_context[EVALUATOR][0],
("fake_evaluator", 3, False, True))