aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/distribute/distribute_coordinator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/distribute/distribute_coordinator.py')
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py140
1 files changed, 116 insertions, 24 deletions
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 46cdd64a6e..d9f78150b9 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -24,9 +24,10 @@ import os
import threading
import time
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -238,19 +239,26 @@ class _WorkerContext(object):
Returns:
a descendant of SessionCreator.
"""
- # TODO(yuefengz): merge session config.
- if self._strategy.should_init:
+ if config:
+ session_config = copy.deepcopy(config)
+ session_config.MergeFrom(self._session_config)
+ else:
+ session_config = self._session_config
+
+ if not self._strategy or self._strategy.should_init:
+ logging.info("Creating chief session creator with config: %r", config)
return monitored_session.ChiefSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=checkpoint_filename_with_path)
else:
+ logging.info("Creating worker session creator with config: %r", config)
return monitored_session.WorkerSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
max_wait_secs=max_wait_secs)
@property
@@ -313,12 +321,17 @@ def _run_single_worker(worker_fn,
rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
+ session_config = copy.deepcopy(session_config)
strategy = copy.deepcopy(strategy)
# If there is an EVALUATOR task, we run single-machine eval on that task.
if task_type == _TaskType.EVALUATOR:
- strategy.configure(session_config)
+ # It is possible to not have a strategy object for EVALUATOR task.
+ if strategy:
+ strategy.configure(session_config)
else:
+ assert strategy
strategy.configure(session_config, cluster_spec, task_type, task_id)
+
context = _WorkerContext(
strategy,
cluster_spec,
@@ -331,6 +344,25 @@ def _run_single_worker(worker_fn,
worker_fn(strategy)
+def _split_cluster_for_evaluator(cluster_spec, task_type):
+ """Split the cluster for evaluator since it needn't talk to other tasks."""
+ # Splitting the cluster is important to prevent the evaluator from talking to
+  # other tasks in the cluster. Since we allow the evaluator not to use
+  # distribution strategies, ops in the evaluator task may have
+  # unspecified devices. Those ops may end up on other tasks if we don't split
+ # the cluster.
+ new_cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec).as_dict()
+ if task_type == _TaskType.EVALUATOR:
+ assert _TaskType.EVALUATOR in new_cluster_spec
+ new_cluster_spec = {
+ _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
+ }
+ else:
+ new_cluster_spec.pop(_TaskType.EVALUATOR, None)
+ return multi_worker_util.normalize_cluster_spec(new_cluster_spec)
+
+
def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
@@ -338,16 +370,19 @@ def _run_std_server(cluster_spec=None,
rpc_layer=None,
environment=None):
"""Runs a standard server."""
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
class _FakeServer(object):
"""A fake server that runs a master session."""
def start(self):
- assert cluster_spec
- target = cluster_spec.task_address(task_type, task_id)
- if rpc_layer:
- target = rpc_layer + "://" + target
# A tensorflow server starts when a remote session is created.
+ logging.info(
+ "Creating a remote session to start a TensorFlow server, "
+ "target = %r, session_config=%r", target, session_config)
session.Session(target=target, config=session_config)
def join(self):
@@ -359,6 +394,13 @@ def _run_std_server(cluster_spec=None,
server.start()
return server
else:
+ if session_config:
+ logging.info(
+ "Starting standard TensorFlow server, target = %r, session_config= "
+ "%r", target, session_config)
+ else:
+ logging.info("Starting standard TensorFlow server, target = %r", target)
+ cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
server = server_lib.Server(
cluster_spec,
job_name=task_type,
@@ -376,7 +418,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -432,6 +474,33 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if eval_thread:
eval_thread.join()
+
+def _configure_session_config_for_std_servers(
+ strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
+ # pylint: disable=g-doc-args
+ """Call strategy's `configure` to mutate the session_config.
+
+ The session_config is currently needed as default config for a TensorFlow
+ server. In the future, we should be able to remove this method and only pass
+ the session config to a client session.
+ """
+ if task_type == _TaskType.EVALUATOR:
+ if eval_strategy:
+ eval_strategy.configure(session_config=session_config)
+ else:
+ # The strategy may be shared in standalone client mode.
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(
+ session_config=session_config,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ # Remove the device filters specific to the strategy, so that the
+ # TensorFlow server brought up with one strategy can be used by other
+ # strategies. The device filters can be set in the client side as well.
+ del session_config.device_filters[:]
+
+
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
@@ -533,8 +602,10 @@ def run_distribute_coordinator(worker_fn,
strategy: a DistributionStrategy object which specifying whether it should
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
- `cluster_spc`, `task_type` and `task_id`.
- eval_fn: optional function for "evaluator" task.
+ `cluster_spec`, `task_type` and `task_id`.
+    eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
+      in but an "evaluator" task is found in the `cluster_spec`, the `worker_fn`
+      will be used for this task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
@@ -558,17 +629,17 @@ def run_distribute_coordinator(worker_fn,
task_id = int(task_env.get("index", task_id))
if cluster_spec:
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
# TODO(yuefengz): validate cluster_spec.
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
environment = tf_config.get("environment", None)
+ # Setting the session config is necessary for some strategies such
+ # CollectiveAllReduceStrategy.
+ session_config = session_config or config_pb2.ConfigProto(
+ allow_soft_placement=True)
+
if cluster_spec:
logging.info(
"Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
@@ -581,11 +652,18 @@ def run_distribute_coordinator(worker_fn,
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
- _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+ _run_single_worker(eval_fn, eval_strategy, None, None, None,
session_config, rpc_layer)
+ else:
+ logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
# The client must know the cluster but servers in the cluster don't have to
# know the client.
@@ -598,10 +676,14 @@ def run_distribute_coordinator(worker_fn,
cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
server.join()
@@ -609,14 +691,24 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
-
- # Every one starts a standard server.
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
+
+ # Every one starts a standard server, get session config from `configure`
+ # method.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)