author    Yuefeng Zhou <yuefengz@google.com>  2018-08-16 12:25:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 12:28:46 -0700
commit    1326f33515dac82fbdd7ef502d1df5e96986fc12 (patch)
tree      0a363b47cf4c6f68ecef108430ced4966b888860 /tensorflow/python/distribute
parent    5360d7368713fa3d4e1aedc682ba8bbd3362deba (diff)
Use distribution strategy to configure distribute coordinator.
Add session_creator and a couple of properties to worker context, which are then used to configure monitored sessions.
PiperOrigin-RevId: 209026599
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/BUILD                                9
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py          204
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_context.py   31
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py     248
4 files changed, 400 insertions, 92 deletions
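
A rough sketch of the calling convention this change introduces (not part of the
patch; `MyStrategy`, `cluster_spec` and `session_config` are illustrative
stand-ins for a real DistributionStrategy, cluster definition and tf.ConfigProto):
`run_distribute_coordinator` now takes a `strategy` argument instead of the old
`between_graph` flag, and `worker_fn` receives a per-task copy of that strategy.

    from tensorflow.python.distribute import distribute_coordinator

    def worker_fn(strategy):
      # Called once per worker (between-graph) or once overall (in-graph),
      # with a copy of `strategy` that has already had `configure` called.
      pass  # training logic goes here

    distribute_coordinator.run_distribute_coordinator(
        worker_fn,
        MyStrategy(),                  # replaces the removed `between_graph` flag
        cluster_spec=cluster_spec,     # dict, ClusterDef or ClusterSpec
        session_config=session_config,
        rpc_layer="grpc")
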
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 68d8b8d13b..16fbe3f4b5 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -41,3 +41,12 @@ py_test(
"//tensorflow/python:variables",
],
)
+
+py_library(
+ name = "distribute_coordinator_context",
+ srcs = [
+ "distribute_coordinator_context.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [],
+)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index fc9ca4ac4a..eb081b65fc 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""A unified and split coordinator for distributed TensorFlow."""
+"""A component for running distributed TensorFlow."""
from __future__ import absolute_import
from __future__ import division
@@ -24,6 +24,8 @@ import os
import threading
from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -43,23 +45,12 @@ class CoordinatorMode(object):
# client and connects to remote servers for training. Each remote server can
# use the distribute coordinator binary with task_type set correctly which
# will then turn into standard servers.
- SPLIT_CLIENT = 0
+ STANDALONE_CLIENT = "standalone_client"
# The distribute coordinator runs on each worker. It will run a standard
# server on each worker and optionally run the `worker_fn` that is configured
# to talk to its standard server.
- INDEPENDENT_WORKER = 1
-
-
-_worker_context = threading.local()
-
-
-def get_current_worker_context():
- """Returns the current task context."""
- try:
- return _worker_context.current
- except AttributeError:
- return None
+ INDEPENDENT_WORKER = "independent_worker"
class _Barrier(object):
@@ -113,14 +104,17 @@ class _WorkerContext(object):
"""
def __init__(self,
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=None,
rpc_layer="grpc",
worker_barrier=None):
"""Initialize the worker context object.
Args:
+ strategy: a `DistributionStrategy` object.
cluster_spec: a ClusterSpec object. It can be empty or None in the local
training case.
task_type: a string indicating the role of the corresponding task, such as
@@ -128,14 +122,17 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
+ session_config: an optional @{tf.ConfigProto} object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
worker_barrier: optional, the barrier object for worker synchronization.
"""
+ self._strategy = strategy
self._cluster_spec = cluster_spec
self._task_type = task_type
self._task_id = task_id
+ self._session_config = session_config
self._worker_barrier = worker_barrier
self._rpc_layer = rpc_layer
self._master_target = self._get_master_target()
@@ -143,26 +140,31 @@ class _WorkerContext(object):
self._is_chief_node = self._is_chief()
def _debug_message(self):
- return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
- self._cluster_spec, self.task_type, self.task_id)
+ if self._cluster_spec:
+ return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
+ self._cluster_spec, self.task_type, self.task_id)
+ else:
+ return "[local]"
def __enter__(self):
- old_context = get_current_worker_context()
+ old_context = distribute_coordinator_context.get_current_worker_context()
if old_context:
raise ValueError(
"You cannot run distribute coordinator in a `worker_fn`.\t" +
self._debug_message())
- _worker_context.current = self
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = self
def __exit__(self, unused_exception_type, unused_exception_value,
unused_traceback):
- _worker_context.current = None
+ # pylint: disable=protected-access
+ distribute_coordinator_context._worker_context.current = None
def _get_master_target(self):
"""Return the master target for a task."""
# If cluster_spec is None or empty, we use local master.
if not self._cluster_spec:
- return "local"
+ return ""
# If task_type is None, then it is in-graph replicated training. In this
# case we use the chief or first worker's master target.
@@ -207,6 +209,47 @@ class _WorkerContext(object):
self._debug_message())
self._worker_barrier.wait()
+ def session_creator(self,
+ scaffold=None,
+ config=None,
+ checkpoint_dir=None,
+ checkpoint_filename_with_path=None,
+ max_wait_secs=7200):
+ """Returns a session creator.
+
+ The returned session creator will be configured with the correct master
+ target and session configs. It will also run either init ops or ready ops
+ by querying the `strategy` object when `create_session` is called on it.
+
+ Args:
+ scaffold: A `Scaffold` used for gathering or building supportive ops. If
+ not specified a default one is created. It's used to finalize the graph.
+ config: `ConfigProto` proto used to configure the session.
+ checkpoint_dir: A string. Optional path to a directory where to restore
+ variables.
+ checkpoint_filename_with_path: Full file name path to the checkpoint file.
+ Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
+ specified.
+ max_wait_secs: Maximum time to wait for the session to become available.
+
+ Returns:
+ a descendant of SessionCreator.
+ """
+ # TODO(yuefengz): merge session config.
+ if self._strategy.should_init:
+ return monitored_session.ChiefSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=config or self._session_config,
+ checkpoint_dir=checkpoint_dir,
+ checkpoint_filename_with_path=checkpoint_filename_with_path)
+ else:
+ return monitored_session.WorkerSessionCreator(
+ scaffold,
+ master=self.master_target,
+ config=config or self._session_config,
+ max_wait_secs=max_wait_secs)
+
@property
def has_barrier(self):
"""Whether the barrier is set or not."""
@@ -247,21 +290,38 @@ class _WorkerContext(object):
"""Returns number of workers in the cluster, including chief."""
return self._num_workers
+ @property
+ def should_checkpoint(self):
+ """Whether to save checkpoint."""
+ return self._strategy.should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ """Whether to save summaries."""
+ return self._strategy.should_save_summary
+
def _run_single_worker(worker_fn,
+ strategy,
cluster_spec,
task_type,
task_id,
- rpc_layer,
+ session_config,
+ rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
- with _WorkerContext(
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(session_config, cluster_spec, task_type, task_id)
+ context = _WorkerContext(
+ strategy,
cluster_spec,
task_type,
task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
- worker_barrier=worker_barrier):
- worker_fn()
+ worker_barrier=worker_barrier)
+ with context:
+ worker_fn(strategy)
def _run_std_server(cluster_spec=None,
@@ -280,13 +340,15 @@ def _run_std_server(cluster_spec=None,
return server
-def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
@@ -298,7 +360,8 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
t = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, task_type, task_id),
+ args=(worker_fn, strategy, cluster_spec, task_type, task_id,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
"worker_barrier": worker_barrier
@@ -315,43 +378,53 @@ def _run_between_graph_client(worker_fn, cluster_spec, rpc_layer):
eval_thread.join()
-def _run_in_graph_client(worker_fn, cluster_spec, rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, cluster_spec, _TaskType.EVALUATOR, 0),
+ args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ session_config),
kwargs={
"rpc_layer": rpc_layer,
})
eval_thread.start()
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(
+ worker_fn,
+ strategy,
+ cluster_spec,
+ None,
+ None,
+ session_config,
+ rpc_layer=rpc_layer)
if eval_thread:
eval_thread.join()
-
-# TODO(yuefengz): propagate cluster_spec in the SPLIT_CLIENT mode.
+# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
- mode=CoordinatorMode.SPLIT_CLIENT,
+ strategy,
+ mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
task_id=None,
- between_graph=False,
+ session_config=None,
rpc_layer="grpc"):
"""Runs the coordinator for distributed TensorFlow.
This function runs a split coordinator for distributed TensorFlow in its
- default mode, i.e the SPLIT_CLIENT mode. Given a `cluster_spec` specifying
- server addresses and their roles in a cluster, this coordinator will figure
- out how to set them up, give the underlying function the right targets for
- master sessions via a scope object and coordinate their training. The cluster
- consisting of standard servers needs to be brought up either with the standard
- server binary or with a binary running distribute coordinator with `task_type`
- set to non-client type which will then turn into standard servers.
+ default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
+ specifying server addresses and their roles in a cluster, this coordinator
+ will figure out how to set them up, give the underlying function the right
+ targets for master sessions via a scope object and coordinate their training.
+ The cluster consisting of standard servers needs to be brought up either with
+ the standard server binary or with a binary running distribute coordinator
+ with `task_type` set to a non-client type, which will then turn them into
+ standard servers.
In addition to being the distribute coordinator, this is also the source of
configurations for each job in the distributed training. As there are multiple
@@ -370,6 +443,14 @@ def run_distribute_coordinator(worker_fn,
`worker_fn` depending on whether it is between-graph training or in-graph
replicated training.
+ The `strategy` object is expected to be a DistributionStrategy object which
+ implements the methods needed by the distribute coordinator, such as
+ `configure(session_config, cluster_spec, task_type, task_id)`, which configures
+ the strategy object for a specific task, and the `should_init` property, which
+ tells the distribute coordinator whether to run init ops for a task. The
+ distribute coordinator will make a copy of the `strategy` object, call its
+ `configure` method and pass it to `worker_fn` as an argument.
+
The `worker_fn` defines the training logic and is called under its own
worker context, which can be accessed via `get_current_worker_context`. A
worker context provides access to configurations for each task, e.g. the
@@ -413,16 +494,20 @@ def run_distribute_coordinator(worker_fn,
evaluation.
Args:
- worker_fn: the function to be called and given the access to a coordinator
- context object.
+ worker_fn: the function to be called. The function should accept a
+ `strategy` object and will be given access to a context object via a
+ context manager scope.
+ strategy: a DistributionStrategy object specifying whether it should
+ run between-graph replicated training or not, whether to run init ops,
+ etc. This object will also be configured given `session_config`,
+ `cluster_spec`, `task_type` and `task_id`.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- between_graph: a boolean. It is only useful when `cluster_spec` is set and
- not empty. If true, it will use between-graph replicated training;
- otherwise it will use in-graph replicated training.
+ session_config: an optional @{tf.ConfigProto} object which will be passed
+ to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
Raises:
@@ -448,15 +533,18 @@ def run_distribute_coordinator(worker_fn,
if not cluster_spec:
# `mode` is ignored in the local case.
- _run_single_worker(worker_fn, None, None, None, rpc_layer)
- elif mode == CoordinatorMode.SPLIT_CLIENT:
+ _run_single_worker(worker_fn, strategy, None, None, None, session_config,
+ rpc_layer)
+ elif mode == CoordinatorMode.STANDALONE_CLIENT:
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
- if between_graph:
- _run_between_graph_client(worker_fn, cluster_spec, rpc_layer)
+ if strategy.between_graph:
+ _run_between_graph_client(worker_fn, strategy, cluster_spec,
+ session_config, rpc_layer)
else:
- _run_in_graph_client(worker_fn, cluster_spec, rpc_layer)
+ _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
+ rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
@@ -471,19 +559,21 @@ def run_distribute_coordinator(worker_fn,
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
- if between_graph:
+ if strategy.between_graph:
# All jobs run `worker_fn` if between-graph.
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id,
- rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
# Only one node runs `worker_fn` if in-graph.
- context = _WorkerContext(cluster_spec, task_type, task_id, rpc_layer)
+ context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
if context.is_chief:
- _run_single_worker(worker_fn, cluster_spec, None, None, rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
+ session_config, rpc_layer)
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(worker_fn, cluster_spec, task_type, task_id, rpc_layer)
+ _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
+ session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/python/distribute/distribute_coordinator_context.py b/tensorflow/python/distribute/distribute_coordinator_context.py
new file mode 100644
index 0000000000..dee65ce883
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_coordinator_context.py
@@ -0,0 +1,31 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The context retrieval method for distribute coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+_worker_context = threading.local()
+
+
+def get_current_worker_context():
+ """Returns the current task context."""
+ try:
+ return _worker_context.current
+ except AttributeError:
+ return None
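
A small usage note on the new module (a sketch, not from the patch): the
thread-local lookup returns None when called outside a `worker_fn`, so callers
can branch on it, e.g.:

    from tensorflow.python.distribute import distribute_coordinator_context

    context = distribute_coordinator_context.get_current_worker_context()
    if context is None:
      master = ""  # not running under a distribute coordinator
    else:
      master = context.master_target
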
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 319c29ba2f..97c6bdd15a 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for distribute coordinator."""
+"""Tests for Distribute Coordinator."""
from __future__ import absolute_import
from __future__ import division
@@ -37,6 +37,7 @@ except ImportError as _error:
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator
+from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
@@ -44,17 +45,17 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
CHIEF = distribute_coordinator._TaskType.CHIEF
WORKER = distribute_coordinator._TaskType.WORKER
PS = distribute_coordinator._TaskType.PS
EVALUATOR = distribute_coordinator._TaskType.EVALUATOR
-SPLIT_CLIENT = distribute_coordinator.CoordinatorMode.SPLIT_CLIENT
+STANDALONE_CLIENT = distribute_coordinator.CoordinatorMode.STANDALONE_CLIENT
INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
-RUN_STD_SERVER_METHOD = "tensorflow.python.distribute.distribute_coordinator._run_std_server"
-
NUM_WORKERS = 3
NUM_PS = 2
@@ -74,6 +75,57 @@ def _strip_protocol(target):
return target
+class MockStrategy(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=None,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self._between_graph = between_graph
+ self._should_init = should_init
+ self._should_checkpoint = should_checkpoint
+ self._should_save_summary = should_save_summary
+
+ @property
+ def between_graph(self):
+ return self._between_graph
+
+ def configure(self,
+ session_options=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del session_options, cluster_spec, task_type
+ if self._should_init is None:
+ if task_id == 0:
+ self._should_init = True
+ else:
+ self._should_init = False
+ if self._should_checkpoint is None:
+ if task_id == 0:
+ self._should_checkpoint = True
+ else:
+ self._should_checkpoint = False
+ if self._should_save_summary is None:
+ if task_id == 0:
+ self._should_save_summary = True
+ else:
+ self._should_save_summary = False
+
+ @property
+ def should_init(self):
+ return self._should_init
+
+ @property
+ def should_checkpoint(self):
+ return self._should_checkpoint
+
+ @property
+ def should_save_summary(self):
+ return self._should_save_summary
+
+
class MockServer(object):
def __init__(self):
@@ -108,6 +160,7 @@ class DistributeCoordinatorTestBase(test.TestCase):
self._result_correct = 0
self._lock = threading.Lock()
self._worker_context = {}
+ self._strategy_property = {}
self._std_servers = {}
self._barrier = distribute_coordinator._Barrier(NUM_WORKERS)
@@ -142,8 +195,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
return cluster_spec
- def _in_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _in_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
xs = []
@@ -164,22 +217,23 @@ class DistributeCoordinatorTestBase(test.TestCase):
if result_value == expected:
self._result_correct += 1
- def _run_coordinator_in_thread(self, worker_fn, **kwargs):
+ def _run_coordinator_in_thread(self, worker_fn, strategy, **kwargs):
t = threading.Thread(
target=distribute_coordinator.run_distribute_coordinator,
- args=(worker_fn,),
+ args=(worker_fn, strategy),
kwargs=kwargs)
t.start()
return t
- def _run_multiple_coordinator_in_threads(self, worker_fn, cluster_spec,
- **kwargs):
+ def _run_multiple_coordinator_in_threads(self, worker_fn, strategy,
+ cluster_spec, **kwargs):
threads = {}
for task_type in cluster_spec.keys():
threads[task_type] = []
for task_id in range(len(cluster_spec[task_type])):
t = self._run_coordinator_in_thread(
worker_fn,
+ strategy,
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
@@ -187,8 +241,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
threads[task_type].append(t)
return threads
- def _between_graph_worker_fn(self):
- context = distribute_coordinator.get_current_worker_context()
+ def _between_graph_worker_fn(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
with self._test_session(target=context.master_target) as sess:
with ops.device("/job:ps/task:0"):
@@ -234,14 +288,50 @@ class DistributeCoordinatorTestBase(test.TestCase):
with self._lock:
self._result_correct += 1
- def _dump_worker_context(self):
+ def _between_graph_with_monitored_session(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+ with ops.device("/job:ps/task:0"):
+ # TODO(yuefengz): investigate why not using resource variable will make
+ # the test flaky.
+ x = variable_scope.get_variable("x", initializer=10.0, use_resource=True)
+ with ops.device("/job:ps/task:1"):
+ y = variable_scope.get_variable("y", initializer=20.0, use_resource=True)
+
+ x_add = x.assign_add(2.0)
+ y_sub = y.assign_sub(2.0)
+ train_op = control_flow_ops.group([x_add, y_sub])
+
+ # The monitored session will run init or ready ops.
+ with monitored_session.MonitoredSession() as sess:
+ sess.run(train_op)
+
+ # Synchronize workers after one step to make sure they all have finished
+ # training.
+ if context.has_barrier:
+ context.wait_for_other_workers()
+ else:
+ self._barrier.wait()
+
+ x_val, y_val = sess.run([x, y])
+
+ self.assertEqual(x_val, 16.0)
+ self.assertEqual(y_val, 14.0)
+ if x_val == 16.0 and y_val == 14.0:
+ with self._lock:
+ self._result_correct += 1
+
+ def _dump_worker_context(self, strategy):
"""Dumps the propoerties of each worker context.
It dumps the context properties to a dict mapping from task_type to a list
of tuples of master_target, num_workers, is_chief and distributed_mode, where
the list is indexed by the task_id.
+
+ Args:
+ strategy: a `DistributionStrategy` object.
"""
- context = distribute_coordinator.get_current_worker_context()
+ context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
task_type = str(context.task_type)
task_id = context.task_id or 0
@@ -255,6 +345,25 @@ class DistributeCoordinatorTestBase(test.TestCase):
context.is_chief,
context.distributed_mode)
+ def _dump_strategy_property(self, strategy):
+ context = distribute_coordinator_context.get_current_worker_context()
+ self.assertTrue(context is not None)
+
+ self.assertEqual(context._strategy.should_init, strategy.should_init)
+ self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
+ self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+
+ task_type = str(context.task_type)
+ task_id = context.task_id or 0
+ with self._lock:
+ if task_type not in self._strategy_property:
+ self._strategy_property[task_type] = []
+ while len(self._strategy_property[task_type]) <= task_id:
+ self._strategy_property[task_type].append(None)
+ self._strategy_property[task_type][task_id] = (
+ context._strategy.should_init, context.should_checkpoint,
+ context.should_save_summary)
+
def _run_mock_std_server(self,
session_config=None,
cluster_spec=None,
@@ -274,22 +383,32 @@ class DistributeCoordinatorTestBase(test.TestCase):
return server
-class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
+class DistributeCoordinatorTestStandaloneMode(DistributeCoordinatorTestBase):
- def testInGraphSplitMode(self):
- """Test it runs in-graph replication in split client mode."""
+ def testInGraphStandaloneMode(self):
+ """Test it runs in-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._in_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
self.assertEqual(self._result_correct, 1)
def testBetweenGraph(self):
- """Test it runs between-graph replication in split client mode."""
+ """Test it runs between-graph replication in standalone client mode."""
distribute_coordinator.run_distribute_coordinator(
self._between_graph_worker_fn,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ """Test monitored session in standalone client mode."""
+ distribute_coordinator.run_distribute_coordinator(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# Each finished worker will increment self._result_correct.
self.assertEqual(self._result_correct, NUM_WORKERS)
@@ -298,8 +417,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=True)
+ MockStrategy(between_graph=True),
+ cluster_spec=self._cluster_spec)
# There is only one type of task and there are three such tasks.
self.assertEqual(len(self._worker_context), 1)
@@ -318,12 +437,30 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
self._worker_context[WORKER][2],
(_bytes_to_str(self._workers[2].target), NUM_WORKERS, False, True))
+ def testBetweenGraphStrategyProperties(self):
+ # Dumps properties of the strategy objects.
+ distribute_coordinator.run_distribute_coordinator(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec=self._cluster_spec)
+
+ # There is only one type of task and there are three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
- cluster_spec=self._cluster_spec,
- between_graph=False)
+ MockStrategy(between_graph=False),
+ cluster_spec=self._cluster_spec)
# There is only a "None" task in the dumped task context.
self.assertEqual(len(self._worker_context), 1)
@@ -339,7 +476,9 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
def testLocalContext(self):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
- self._dump_worker_context, cluster_spec=None, between_graph=True)
+ self._dump_worker_context,
+ MockStrategy(between_graph=False),
+ cluster_spec=None)
# There is only a "None" task.
self.assertEqual(len(self._worker_context), 1)
@@ -348,7 +487,7 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Check whether each task has the right master_target, num_workers, is_chief
# and distributed_mode.
- self.assertEqual(self._worker_context["None"][0], ("local", 0, True, False))
+ self.assertEqual(self._worker_context["None"][0], ("", 0, True, False))
def testBetweenGraphContextWithChief(self):
# Adds a chief node, so there are NUM_WORKERS + 1 workers in total.
@@ -358,8 +497,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec=cluster_spec,
- between_graph=True,
rpc_layer="grpc")
# There are one CHIEF and three workers.
@@ -391,8 +530,8 @@ class DistributeCoordinatorTestSplitMode(DistributeCoordinatorTestBase):
# Dumps the task contexts to the self._worker_context dict.
distribute_coordinator.run_distribute_coordinator(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec=cluster_spec,
- between_graph=False,
rpc_layer=None)
# There are one "None" task and one EVALUATOR task.
@@ -417,8 +556,8 @@ class DistributeCoordinatorTestInpendentWorkerMode(
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
threads = self._run_multiple_coordinator_in_threads(
self._in_graph_worker_fn,
+ MockStrategy(between_graph=False),
cluster_spec,
- between_graph=False,
mode=INDEPENDENT_WORKER)
threads[WORKER][0].join()
self.assertEqual(self._result_correct, 1)
@@ -428,8 +567,22 @@ class DistributeCoordinatorTestInpendentWorkerMode(
num_workers=NUM_WORKERS, num_ps=NUM_PS)
threads = self._run_multiple_coordinator_in_threads(
self._between_graph_worker_fn,
+ MockStrategy(between_graph=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # Each finished worker will increment self._result_correct.
+ self.assertEqual(self._result_correct, NUM_WORKERS)
+
+ def testBetweenGraphWithMonitoredSession(self):
+ cluster_spec = self._create_cluster_spec(
+ num_workers=NUM_WORKERS, num_ps=NUM_PS)
+ threads = self._run_multiple_coordinator_in_threads(
+ self._between_graph_with_monitored_session,
+ MockStrategy(between_graph=True),
cluster_spec,
- between_graph=True,
mode=INDEPENDENT_WORKER)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -444,9 +597,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=True),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=True,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -476,6 +629,31 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertFalse(self._std_servers[WORKER][1].joined)
self.assertFalse(self._std_servers[WORKER][2].joined)
+ def testBetweenGraphStrategyProperties(self):
+ cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
+ # Dumps properties of the strategy objects.
+ with test.mock.patch.object(distribute_coordinator, "_run_std_server",
+ self._run_mock_std_server):
+ threads = self._run_multiple_coordinator_in_threads(
+ self._dump_strategy_property,
+ MockStrategy(between_graph=True, should_init=True),
+ cluster_spec,
+ mode=INDEPENDENT_WORKER,
+ rpc_layer=None)
+ for task_id in range(NUM_WORKERS):
+ threads[WORKER][task_id].join()
+
+ # There is only one type of task and there are three such tasks.
+ self.assertEqual(len(self._strategy_property), 1)
+ self.assertTrue(WORKER in self._strategy_property)
+ self.assertEqual(len(self._strategy_property[WORKER]), NUM_WORKERS)
+
+ # Check whether each task has the right properties of should_init,
+ # should_checkpoint and should_save_summary.
+ self.assertEqual(self._strategy_property[WORKER][0], (True, True, True))
+ self.assertEqual(self._strategy_property[WORKER][1], (True, False, False))
+ self.assertEqual(self._strategy_property[WORKER][2], (True, False, False))
+
def testInGraphContext(self):
cluster_spec = self._create_cluster_spec(num_workers=NUM_WORKERS)
# Dumps the task contexts and std server arguments.
@@ -483,9 +661,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()
@@ -519,9 +697,9 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self._run_mock_std_server):
threads = self._run_multiple_coordinator_in_threads(
self._dump_worker_context,
+ MockStrategy(between_graph=False),
cluster_spec,
mode=INDEPENDENT_WORKER,
- between_graph=False,
rpc_layer=None)
for task_id in range(NUM_WORKERS):
threads[WORKER][task_id].join()