diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-08-24 19:14:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-24 19:21:07 -0700 |
commit | ca94990804cf5326c0f6f46d75c96e0f0e240366 (patch) | |
tree | ddb7499d3404d1c9f08b8d2364053111101a677f /tensorflow/python/distribute | |
parent | 9599b473035fb9f38959608f32180e22216c7dbc (diff) |
Add an option to RunConfig and train_and_evaluate to run distribute coordinator.
This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator.
PiperOrigin-RevId: 210192378
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r-- | tensorflow/python/distribute/BUILD | 33 | ||||
-rw-r--r-- | tensorflow/python/distribute/distribute_config.py | 45 | ||||
-rw-r--r-- | tensorflow/python/distribute/distribute_coordinator.py | 53 | ||||
-rw-r--r-- | tensorflow/python/distribute/estimator_training.py | 264 |
4 files changed, 378 insertions, 17 deletions
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 98ef9bf492..ebfcd085e6 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -9,6 +9,25 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") py_library( + name = "distribute", + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":distribute_config", + ":distribute_coordinator", + ":distribute_coordinator_context", + ], +) + +py_library( + name = "distribute_config", + srcs = [ + "distribute_config.py", + ], + deps = [], +) + +py_library( name = "distribute_coordinator", srcs = [ "distribute_coordinator.py", @@ -81,3 +100,17 @@ py_test( "@absl_py//absl/testing:parameterized", ], ) + +# Used only by estimator. +py_library( + name = "estimator_training", + srcs = [ + "estimator_training.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":distribute_coordinator", + ":distribute_coordinator_context", + "//tensorflow/python:training", + ], +) diff --git a/tensorflow/python/distribute/distribute_config.py b/tensorflow/python/distribute/distribute_config.py new file mode 100644 index 0000000000..fac35742fe --- /dev/null +++ b/tensorflow/python/distribute/distribute_config.py @@ -0,0 +1,45 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A configure tuple for high-level APIs for running distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + + +class DistributeConfig( + collections.namedtuple( + 'DistributeConfig', + ['train_distribute', 'eval_distribute', 'remote_cluster'])): + """A config tuple for distribution strategies. + + Attributes: + train_distribute: a `DistributionStrategy` object for training. + eval_distribute: an optional `DistributionStrategy` object for + evaluation. + remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying + the cluster configurations. If this is given, the `train_and_evaluate` + method will be running as a standalone client which connects to the + cluster for training. + """ + + def __new__(cls, + train_distribute=None, + eval_distribute=None, + remote_cluster=None): + return super(DistributeConfig, cls).__new__(cls, train_distribute, + eval_distribute, remote_cluster) diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py index eb081b65fc..9cf0b3b7a6 100644 --- a/tensorflow/python/distribute/distribute_coordinator.py +++ b/tensorflow/python/distribute/distribute_coordinator.py @@ -311,7 +311,11 @@ def _run_single_worker(worker_fn, worker_barrier=None): """Runs a single worker by calling `worker_fn` under context.""" strategy = copy.deepcopy(strategy) - strategy.configure(session_config, cluster_spec, task_type, task_id) + # If there is an EVALUATOR task, we run single-machine eval on that task. 
+ if task_type == _TaskType.EVALUATOR: + strategy.configure(session_config) + else: + strategy.configure(session_config, cluster_spec, task_type, task_id) context = _WorkerContext( strategy, cluster_spec, @@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None, return server -def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer): +def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer): """Runs a standalone client for between-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0, session_config), kwargs={ "rpc_layer": rpc_layer, @@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config, eval_thread.join() -def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer): +def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer): """Runs a standalone client for in-graph replication.""" eval_thread = None if _TaskType.EVALUATOR in cluster_spec.jobs: eval_thread = threading.Thread( target=_run_single_worker, - args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0, + args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0, session_config), kwargs={ "rpc_layer": rpc_layer, @@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, # is the special task when we support cluster_spec propagation. 
def run_distribute_coordinator(worker_fn, strategy, + eval_fn=None, + eval_strategy=None, mode=CoordinatorMode.STANDALONE_CLIENT, cluster_spec=None, task_type=None, @@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn, If `cluster_spec` is not given in any format, it becomes local training and this coordinator will connect to a local session. - For evaluation, if "evaluator" exist in the cluster_spec, a separate thread - will be created with its `task_type` set to "evaluator". If "evaluator" is not - set in the cluster_spec, it entirely depends on the `worker_fn` for how to do - evaluation. + For evaluation, if "evaluator" exists in the cluster_spec, a separate thread + will be created to call `eval_fn` with its `task_type` set to "evaluator". If + `eval_fn` is not defined, fall back to `worker_fn`. This implies that + evaluation will be done on a single machine if there is an "evaluator" task. + If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the + `worker_fn` for how to do evaluation. Args: worker_fn: the function to be called. The function should accept a @@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn, run between-graph replicated training or not, whether to run init ops, etc. This object will also be configured given `session_config`, `cluster_spec`, `task_type` and `task_id`. + eval_fn: optional function for "evaluator" task. + eval_strategy: optional DistributionStrategy object for "evaluator" task. mode: in which mode this distribute coordinator runs. cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles in a cluster. If not set or empty, fall back to local training. @@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn, # `mode` is ignored in the local case. 
_run_single_worker(worker_fn, strategy, None, None, None, session_config, rpc_layer) + if eval_fn: + _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None, + session_config, rpc_layer) elif mode == CoordinatorMode.STANDALONE_CLIENT: + eval_fn = eval_fn or worker_fn + eval_strategy = eval_strategy or strategy + # The client must know the cluster but servers in the cluster don't have to # know the client. if task_type in [_TaskType.CLIENT, None]: if strategy.between_graph: - _run_between_graph_client(worker_fn, strategy, cluster_spec, - session_config, rpc_layer) + _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer) else: - _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config, - rpc_layer) + _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, + cluster_spec, session_config, rpc_layer) else: # If not a client job, run the standard server. server = _run_std_server( @@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn, if mode != CoordinatorMode.INDEPENDENT_WORKER: raise ValueError("Unexpected coordinator mode: %r" % mode) + eval_fn = eval_fn or worker_fn + eval_strategy = eval_strategy or strategy + # Every one starts a standard server. 
server = _run_std_server( cluster_spec=cluster_spec, task_type=task_type, task_id=task_id) @@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn, else: server.join() elif task_type == _TaskType.EVALUATOR: - _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id, - session_config, rpc_layer) + _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type, + task_id, session_config, rpc_layer) else: if task_type != _TaskType.PS: raise ValueError("Unexpected task_type: %r" % task_type) diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py new file mode 100644 index 0000000000..202e19c420 --- /dev/null +++ b/tensorflow/python/distribute/estimator_training.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Training utilities for Estimator to use Distribute Coordinator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.python.distribute import distribute_coordinator as dc +from tensorflow.python.distribute import distribute_coordinator_context as dc_context +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import server_lib + +# pylint: disable=protected-access +CHIEF = dc._TaskType.CHIEF +EVALUATOR = dc._TaskType.EVALUATOR +PS = dc._TaskType.PS +WORKER = dc._TaskType.WORKER + +# pylint: enable=protected-access + + +def _count_ps(cluster_spec): + """Counts the number of parameter servers in cluster_spec.""" + if not cluster_spec: + raise RuntimeError( + 'Internal error: `_count_ps` does not expect empty cluster_spec.') + + return len(cluster_spec.as_dict().get(PS, [])) + + +def _count_worker(cluster_spec, chief_task_type): + """Counts the number of workers (including chief) in cluster_spec.""" + if not cluster_spec: + raise RuntimeError( + 'Internal error: `_count_worker` does not expect empty cluster_spec.') + + return (len(cluster_spec.as_dict().get(WORKER, [])) + len( + cluster_spec.as_dict().get(chief_task_type, []))) + + +def _get_global_id(cluster_spec, task_type, task_id, chief_task_type): + """Returns the global id of the given task type in a cluster.""" + if not task_type: + return 0 + + # Sort task names in cluster by "chief"/"master", "evaluator", "worker" + # and "ps". More details can be found at the documentation of + # @{tf.estimator.RunConfig.global_id_in_cluster}. 
+ task_type_ordered_list = [] + if chief_task_type in cluster_spec.jobs: + task_type_ordered_list = [chief_task_type] + task_type_ordered_list.extend([ + t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS + ]) + if PS in cluster_spec.jobs: + task_type_ordered_list.append(PS) + + # Find the right global_id for current task. + next_global_id = 0 + for t in task_type_ordered_list: + if t == task_type: + return next_global_id + task_id + # `cluster_spec.job_tasks` returns all task addresses of type `t`. + next_global_id += len(cluster_spec.job_tasks(t)) + + # It is unexpected that it passes through all task_types in + # `task_type_ordered_list`. + raise RuntimeError('Internal Error: `task_type` ({}) is not in ' + 'cluster_spec ({}).'.format(task_type, cluster_spec)) + + +def _init_run_config_from_worker_context(config, worker_context): + """Initializes run config from distribute coordinator's worker context.""" + + # pylint: disable=protected-access + config._service = None + config._cluster_spec = worker_context.cluster_spec + config._task_type = worker_context.task_type + config._task_id = worker_context.task_id + config._evaluation_master = worker_context.master_target + config._master = worker_context.master_target + config._is_chief = worker_context.is_chief + + if config._cluster_spec: + # Distributed mode. + if config._task_type != EVALUATOR: + + config._num_ps_replicas = _count_ps(config._cluster_spec) + config._num_worker_replicas = _count_worker( + config._cluster_spec, chief_task_type=CHIEF) + config._global_id_in_cluster = _get_global_id( + config._cluster_spec, + config._task_type, + config._task_id, + chief_task_type=CHIEF) + else: + # Evaluator task should not be aware of the other tasks. + config._cluster_spec = server_lib.ClusterSpec({}) + config._num_ps_replicas = 0 + config._num_worker_replicas = 0 + config._global_id_in_cluster = None # undefined + else: + # Local mode. 
+ config._global_id_in_cluster = 0 + config._num_ps_replicas = 0 + config._num_worker_replicas = 1 + + +def init_run_config(config, tf_config): + """Initializes RunConfig for distribution strategies.""" + # pylint: disable=protected-access + if (config._experimental_distribute and + config._experimental_distribute.train_distribute): + if config._train_distribute: + raise ValueError('Either `train_distribute` or' + '`experimental_distribute.train_distribute` can be set.') + config._train_distribute = config._experimental_distribute.train_distribute + + if (config._experimental_distribute and + config._experimental_distribute.eval_distribute): + if config._eval_distribute: + raise ValueError('Either `eval_distribute` or' + '`experimental_distribute.eval_distribute` can be set.') + config._eval_distribute = config._experimental_distribute.eval_distribute + + cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {})) + config._init_distributed_setting_from_environment_var({}) + + # Use distribute coordinator with STANDALONE_CLIENT mode if + # `experimental_distribute.remote_cluster` is set. + if (config._train_distribute and config._experimental_distribute and + config._experimental_distribute.remote_cluster): + if tf_config: + raise ValueError('Cannot set both TF_CONFIG environment variable and ' + '`experimental_distribute.remote_cluster`') + config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT + config._cluster_spec = config._experimental_distribute.remote_cluster + logging.info('RunConfig initialized for Distribute Coordinator with ' + 'STANDALONE_CLIENT mode') + return + + # Don't use distribute coordinator if it is local training or cluster has a + # MASTER job or `train_distribute` is not specified. 
+ if (not tf_config or 'master' in cluster_spec.jobs or + not config._train_distribute): + config._distribute_coordinator_mode = None + config._init_distributed_setting_from_environment_var(tf_config) + config._maybe_overwrite_session_config_for_distributed_training() + logging.info('Not using Distribute Coordinator.') + return + + # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise. + assert tf_config + + # Set the cluster_spec only since the distributed setting will come from + # distribute coordinator. + config._cluster_spec = cluster_spec + config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER + logging.info('RunConfig initialized for Distribute Coordinator with ' + 'INDEPENDENT_WORKER mode') + + +def should_run_distribute_coordinator(config): + """Checks the config to see whether to run distribute coordinator.""" + # pylint: disable=protected-access + if (not hasattr(config, '_distribute_coordinator_mode') or + config._distribute_coordinator_mode is None): + return False + if (not isinstance(config._distribute_coordinator_mode, six.string_types) or + config._distribute_coordinator_mode not in [ + dc.CoordinatorMode.STANDALONE_CLIENT, + dc.CoordinatorMode.INDEPENDENT_WORKER + ]): + logging.warning('Unexpected distribute_coordinator_mode: %r', + config._distribute_coordinator_mode) + return False + if not config.cluster_spec: + logging.warning('Running `train_and_evaluate` locally, ignoring ' + '`experimental_distribute_coordinator_mode`.') + return False + return True + + +def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls): + """Run distribute coordinator for Estimator's `train_and_evaluate`. + + Args: + estimator: An `Estimator` instance to train and evaluate. + train_spec: A `TrainSpec` instance to specify the training specification. + eval_spec: A `EvalSpec` instance to specify the evaluation and export + specification. + executor_cls: the evaluation executor class of Estimator. 
+ + Raises: + ValueError: if `distribute_coordinator_mode` is None in RunConfig. + """ + run_config = estimator.config + if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access + raise ValueError( + 'Distribute coordinator mode is not specified in `RunConfig`.') + + def _worker_fn(strategy): + """Function for worker task.""" + local_estimator = copy.deepcopy(estimator) + # pylint: disable=protected-access + local_estimator._config._train_distribute = strategy + _init_run_config_from_worker_context( + local_estimator._config, dc_context.get_current_worker_context()) + local_estimator._train_distribution = strategy + # pylint: enable=protected-access + + local_estimator.train( + input_fn=train_spec.input_fn, + max_steps=train_spec.max_steps, + hooks=list(train_spec.hooks)) + + def _eval_fn(strategy): + """Function for evaluator task.""" + local_estimator = copy.deepcopy(estimator) + # pylint: disable=protected-access + local_estimator._config._eval_distribute = strategy + _init_run_config_from_worker_context( + local_estimator._config, dc_context.get_current_worker_context()) + local_estimator._eval_distribution = strategy + + executor = executor_cls(local_estimator, train_spec, eval_spec) + executor._start_continuous_evaluation() + # pylint: enable=protected-access + + # pylint: disable=protected-access + if (run_config._distribute_coordinator_mode == + dc.CoordinatorMode.STANDALONE_CLIENT): + cluster_spec = run_config.cluster_spec + assert cluster_spec + else: + # The cluster_spec comes from TF_CONFIG environment variable if it is + # INDEPENDENT_WORKER mode. + cluster_spec = None + + dc.run_distribute_coordinator( + _worker_fn, + run_config.train_distribute, + _eval_fn, + run_config.eval_distribute, + mode=run_config._distribute_coordinator_mode, + cluster_spec=cluster_spec, + session_config=run_config.session_config) |