From ab96371eaea4cc5b2f9c431eec455a1cf4be7c1c Mon Sep 17 00:00:00 2001 From: Yuefeng Zhou Date: Thu, 30 Aug 2018 18:20:57 -0700 Subject: Set session configs for ParameterServerStrategy and CollectiveAllReduceStrategy in their configure methods. Configure the session configs in DistributeCoordinator. Allow eval strategy to be None in DistributeCoordinator. Add more loggings. PiperOrigin-RevId: 211017587 --- .../python/distribute/distribute_coordinator.py | 140 +++++++++++++++++---- 1 file changed, 116 insertions(+), 24 deletions(-) (limited to 'tensorflow/python/distribute/distribute_coordinator.py') diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index 46cdd64a6e..d9f78150b9 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -24,9 +24,10 @@ import os import threading import time -from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.distribute import distribute_coordinator_context +from tensorflow.python.distribute import multi_worker_util from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import monitored_session from tensorflow.python.training import server_lib @@ -238,19 +239,26 @@ class _WorkerContext(object): Returns: a descendant of SessionCreator. """ - # TODO(yuefengz): merge session config. 
- if self._strategy.should_init: + if config: + session_config = copy.deepcopy(config) + session_config.MergeFrom(self._session_config) + else: + session_config = self._session_config + + if not self._strategy or self._strategy.should_init: + logging.info("Creating chief session creator with config: %r", config) return monitored_session.ChiefSessionCreator( scaffold, master=self.master_target, - config=config or self._session_config, + config=session_config, checkpoint_dir=checkpoint_dir, checkpoint_filename_with_path=checkpoint_filename_with_path) else: + logging.info("Creating worker session creator with config: %r", config) return monitored_session.WorkerSessionCreator( scaffold, master=self.master_target, - config=config or self._session_config, + config=session_config, max_wait_secs=max_wait_secs) @property @@ -313,12 +321,17 @@ def _run_single_worker(worker_fn, rpc_layer="", worker_barrier=None): """Runs a single worker by calling `worker_fn` under context.""" + session_config = copy.deepcopy(session_config) strategy = copy.deepcopy(strategy) # If there is an EVALUATOR task, we run single-machine eval on that task. if task_type == _TaskType.EVALUATOR: - strategy.configure(session_config) + # It is possible to not have a strategy object for EVALUATOR task. + if strategy: + strategy.configure(session_config) else: + assert strategy strategy.configure(session_config, cluster_spec, task_type, task_id) + context = _WorkerContext( strategy, cluster_spec, @@ -331,6 +344,25 @@ def _run_single_worker(worker_fn, worker_fn(strategy) +def _split_cluster_for_evaluator(cluster_spec, task_type): + """Split the cluster for evaluator since it needn't talk to other tasks.""" + # Splitting the cluster is important to prevent the evaluator from talking to + # other tasks in the cluster. Since we allow evaluator not to use + # distribution strategies and as a result ops in the evaluator task may have + # unspecified devices. 
Those ops may end up on other tasks if we don't split + # the cluster. + new_cluster_spec = multi_worker_util.normalize_cluster_spec( + cluster_spec).as_dict() + if task_type == _TaskType.EVALUATOR: + assert _TaskType.EVALUATOR in new_cluster_spec + new_cluster_spec = { + _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR] + } + else: + new_cluster_spec.pop(_TaskType.EVALUATOR, None) + return multi_worker_util.normalize_cluster_spec(new_cluster_spec) + + def _run_std_server(cluster_spec=None, task_type=None, task_id=None, @@ -338,16 +370,19 @@ def _run_std_server(cluster_spec=None, rpc_layer=None, environment=None): """Runs a standard server.""" + assert cluster_spec + target = cluster_spec.task_address(task_type, task_id) + if rpc_layer: + target = rpc_layer + "://" + target class _FakeServer(object): """A fake server that runs a master session.""" def start(self): - assert cluster_spec - target = cluster_spec.task_address(task_type, task_id) - if rpc_layer: - target = rpc_layer + "://" + target # A tensorflow server starts when a remote session is created. 
+ logging.info( + "Creating a remote session to start a TensorFlow server, " + "target = %r, session_config=%r", target, session_config) session.Session(target=target, config=session_config) def join(self): @@ -359,6 +394,13 @@ def _run_std_server(cluster_spec=None, server.start() return server else: + if session_config: + logging.info( + "Starting standard TensorFlow server, target = %r, session_config= " + "%r", target, session_config) + else: + logging.info("Starting standard TensorFlow server, target = %r", target) + cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type) server = server_lib.Server( cluster_spec, job_name=task_type, @@ -376,7 +418,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0, + args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0, session_config), kwargs={ "rpc_layer": rpc_layer, @@ -432,6 +474,33 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, if eval_thread: eval_thread.join() + +def _configure_session_config_for_std_servers( + strategy, eval_strategy, session_config, cluster_spec, task_type, task_id): + # pylint: disable=g-doc-args + """Call strategy's `configure` to mutate the session_config. + + The session_config is currently needed as default config for a TensorFlow + server. In the future, we should be able to remove this method and only pass + the session config to a client session. + """ + if task_type == _TaskType.EVALUATOR: + if eval_strategy: + eval_strategy.configure(session_config=session_config) + else: + # The strategy may be shared in standalone client mode. 
+ strategy = copy.deepcopy(strategy) + strategy.configure( + session_config=session_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + # Remove the device filters specific to the strategy, so that the + # TensorFlow server brought up with one strategy can be used by other + # strategies. The device filters can be set in the client side as well. + del session_config.device_filters[:] + + # TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode. # TODO(yuefengz): we may need a smart way to figure out whether the current task # is the special task when we support cluster_spec propagation. @@ -533,8 +602,10 @@ def run_distribute_coordinator(worker_fn, strategy: a DistributionStrategy object which specifying whether it should run between-graph replicated training or not, whether to run init ops, etc. This object will also be configured given `session_config`, - `cluster_spc`, `task_type` and `task_id`. - eval_fn: optional function for "evaluator" task. + `cluster_spec`, `task_type` and `task_id`. + eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed + in but a "evaluator" task found in the `cluster_spec`, the `worker_fn` + will be used for this task. eval_strategy: optional DistributionStrategy object for "evaluator" task. mode: in which mode this distribute coordinator runs. cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles @@ -558,17 +629,17 @@ def run_distribute_coordinator(worker_fn, task_id = int(task_env.get("index", task_id)) if cluster_spec: - if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)): - cluster_spec = server_lib.ClusterSpec(cluster_spec) - elif not isinstance(cluster_spec, server_lib.ClusterSpec): - raise ValueError( - "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a " - "`tf.train.ClusterDef` object") + cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) # TODO(yuefengz): validate cluster_spec. 
rpc_layer = tf_config.get("rpc_layer", rpc_layer) environment = tf_config.get("environment", None) + # Setting the session config is necessary for some strategies such + # as CollectiveAllReduceStrategy. + session_config = session_config or config_pb2.ConfigProto( + allow_soft_placement=True) + if cluster_spec: logging.info( "Running Distribute Coordinator with mode = %r, cluster_spec = %r, " @@ -581,11 +652,18 @@ def run_distribute_coordinator(worker_fn, _run_single_worker(worker_fn, strategy, None, None, None, session_config, rpc_layer) if eval_fn: - _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None, + _run_single_worker(eval_fn, eval_strategy, None, None, None, session_config, rpc_layer) + else: + logging.warning("Skipped evaluation since `eval_fn` is not passed in.") elif mode == CoordinatorMode.STANDALONE_CLIENT: + if not eval_fn: + logging.warning("`eval_fn` is not passed in. The `worker_fn` will be " + "used if an \"evaluator\" task exists in the cluster.") eval_fn = eval_fn or worker_fn - eval_strategy = eval_strategy or strategy + if not eval_strategy: + logging.warning("`eval_strategy` is not passed in. No distribution " + "strategy will be used for evaluation.") # The client must know the cluster but servers in the cluster don't have to # know the client. @@ -598,10 +676,14 @@ def run_distribute_coordinator(worker_fn, cluster_spec, session_config, rpc_layer) else: # If not a client job, run the standard server. 
+ _configure_session_config_for_std_servers(strategy, eval_strategy, + session_config, cluster_spec, + task_type, task_id) server = _run_std_server( cluster_spec=cluster_spec, task_type=task_type, task_id=task_id, + session_config=session_config, rpc_layer=rpc_layer, environment=environment) server.join() @@ -609,14 +691,24 @@ def run_distribute_coordinator(worker_fn, if mode != CoordinatorMode.INDEPENDENT_WORKER: raise ValueError("Unexpected coordinator mode: %r" % mode) + if not eval_fn: + logging.warning("`eval_fn` is not passed in. The `worker_fn` will be " + "used if an \"evaluator\" task exists in the cluster.") eval_fn = eval_fn or worker_fn - eval_strategy = eval_strategy or strategy - - # Every one starts a standard server. + if not eval_strategy: + logging.warning("`eval_strategy` is not passed in. No distribution " + "strategy will be used for evaluation.") + + # Every one starts a standard server, get session config from `configure` + # method. + _configure_session_config_for_std_servers(strategy, eval_strategy, + session_config, cluster_spec, + task_type, task_id) server = _run_std_server( cluster_spec=cluster_spec, task_type=task_type, task_id=task_id, + session_config=session_config, rpc_layer=rpc_layer, environment=environment) -- cgit v1.2.3