author     Yuefeng Zhou <yuefengz@google.com>              2018-08-24 19:14:44 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-08-24 19:21:07 -0700
commit     ca94990804cf5326c0f6f46d75c96e0f0e240366 (patch)
tree       ddb7499d3404d1c9f08b8d2364053111101a677f /tensorflow/python/distribute
parent     9599b473035fb9f38959608f32180e22216c7dbc (diff)
Add an option to RunConfig and train_and_evaluate to run distribute coordinator.
This is necessary to run multi-worker MirroredStrategy and CollectiveAllReduceStrategy with estimator.
PiperOrigin-RevId: 210192378
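To make the new option concrete, here is a minimal sketch of the intended wiring. It assumes the companion Estimator-side change exposes `DistributeConfig` through a RunConfig argument named `experimental_distribute` (that change lives outside this path-limited diff), and it uses `tf.contrib.distribute.MirroredStrategy` purely as an illustrative strategy:

    import tensorflow as tf
    from tensorflow.python.distribute.distribute_config import DistributeConfig

    # Hypothetical wiring: a DistributeConfig carrying a training strategy and a
    # remote cluster, handed to RunConfig so that train_and_evaluate goes through
    # the distribute coordinator in standalone-client mode.
    distribute_config = DistributeConfig(
        train_distribute=tf.contrib.distribute.MirroredStrategy(),
        remote_cluster={'chief': ['host0:2222'],
                        'worker': ['host1:2222', 'host2:2222']})

    # `experimental_distribute` is an assumption here; the RunConfig-side change
    # is in the Estimator package and is not shown in this diff.
    run_config = tf.estimator.RunConfig(experimental_distribute=distribute_config)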
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/BUILD                      |  33
-rw-r--r--  tensorflow/python/distribute/distribute_config.py       |  45
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py  |  53
-rw-r--r--  tensorflow/python/distribute/estimator_training.py      | 264
4 files changed, 378 insertions(+), 17 deletions(-)
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 98ef9bf492..ebfcd085e6 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -9,6 +9,25 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
+ name = "distribute",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":distribute_config",
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ ],
+)
+
+py_library(
+ name = "distribute_config",
+ srcs = [
+ "distribute_config.py",
+ ],
+ deps = [],
+)
+
+py_library(
name = "distribute_coordinator",
srcs = [
"distribute_coordinator.py",
@@ -81,3 +100,17 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+# Used only by estimator.
+py_library(
+ name = "estimator_training",
+ srcs = [
+ "estimator_training.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ "//tensorflow/python:training",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_config.py b/tensorflow/python/distribute/distribute_config.py
new file mode 100644
index 0000000000..fac35742fe
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_config.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A configure tuple for high-level APIs for running distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+class DistributeConfig(
+ collections.namedtuple(
+ 'DistributeConfig',
+ ['train_distribute', 'eval_distribute', 'remote_cluster'])):
+ """A config tuple for distribution strategies.
+
+ Attributes:
+ train_distribute: a `DistributionStrategy` object for training.
+ eval_distribute: an optional `DistributionStrategy` object for
+ evaluation.
+ remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
+ the cluster configurations. If this is given, the `train_and_evaluate`
+ method will run as a standalone client that connects to the
+ cluster for training.
+ """
+
+ def __new__(cls,
+ train_distribute=None,
+ eval_distribute=None,
+ remote_cluster=None):
+ return super(DistributeConfig, cls).__new__(cls, train_distribute,
+ eval_distribute, remote_cluster)
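Since `DistributeConfig` is a plain namedtuple whose three fields all default to `None`, its behavior can be spot-checked without a running cluster; a minimal sketch:

    from tensorflow.python.distribute.distribute_config import DistributeConfig

    # Partial construction is fine; unset fields stay None.
    config = DistributeConfig(remote_cluster={'worker': ['host1:2222']})
    assert config.train_distribute is None
    assert config.eval_distribute is None
    assert config.remote_cluster == {'worker': ['host1:2222']}

    # Standard namedtuple semantics apply: fields are read-only, so a modified
    # copy comes from _replace rather than attribute assignment.
    updated = config._replace(remote_cluster=None)
    assert updated.remote_cluster is None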
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index eb081b65fc..9cf0b3b7a6 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -311,7 +311,11 @@ def _run_single_worker(worker_fn,
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
strategy = copy.deepcopy(strategy)
- strategy.configure(session_config, cluster_spec, task_type, task_id)
+ # If there is an EVALUATOR task, we run single-machine eval on that task.
+ if task_type == _TaskType.EVALUATOR:
+ strategy.configure(session_config)
+ else:
+ strategy.configure(session_config, cluster_spec, task_type, task_id)
context = _WorkerContext(
strategy,
cluster_spec,
@@ -340,14 +344,14 @@ def _run_std_server(cluster_spec=None,
return server
-def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer):
+def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -378,14 +382,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
eval_thread.join()
-def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -408,6 +412,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
strategy,
+ eval_fn=None,
+ eval_strategy=None,
mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
@@ -488,10 +494,12 @@ def run_distribute_coordinator(worker_fn,
If `cluster_spec` is not given in any format, it becomes local training and
this coordinator will connect to a local session.
- For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
- will be created with its `task_type` set to "evaluator". If "evaluator" is not
- set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
- evaluation.
+ For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
+ will be created to call `eval_fn` with its `task_type` set to "evaluator". If
+ `eval_fn` is not defined, the coordinator falls back to `worker_fn`. This implies that
+ evaluation will be done on a single machine if there is an "evaluator" task.
+ If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
+ `worker_fn` for how to do evaluation.
Args:
worker_fn: the function to be called. The function should accept a
@@ -501,6 +509,8 @@ def run_distribute_coordinator(worker_fn,
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
`cluster_spec`, `task_type` and `task_id`.
+ eval_fn: optional function for "evaluator" task.
+ eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
@@ -535,16 +545,22 @@ def run_distribute_coordinator(worker_fn,
# `mode` is ignored in the local case.
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
+ if eval_fn:
+ _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+ session_config, rpc_layer)
elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ eval_fn = eval_fn or worker_fn
+ eval_strategy = eval_strategy or strategy
+
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
if strategy.between_graph:
- _run_between_graph_client(worker_fn, strategy, cluster_spec,
- session_config, rpc_layer)
+ _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
- _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer)
+ _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
@@ -554,6 +570,9 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
+ eval_fn = eval_fn or worker_fn
+ eval_strategy = eval_strategy or strategy
+
# Every one starts a standard server.
server = _run_std_server(
cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
@@ -572,8 +591,8 @@ def run_distribute_coordinator(worker_fn,
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
- session_config, rpc_layer)
+ _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
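A condensed sketch of how a standalone client might exercise the new `eval_fn`/`eval_strategy` parameters; the cluster layout, the no-op worker and eval bodies, and the choice of `tf.contrib.distribute.MirroredStrategy` are illustrative assumptions:

    import tensorflow as tf
    from tensorflow.python.distribute import distribute_coordinator as dc

    def worker_fn(strategy):
      # Training work for chief/worker tasks; `strategy` has already been
      # configured with the cluster_spec, task_type and task_id for this task.
      pass

    def eval_fn(strategy):
      # Runs on the "evaluator" task: in a separate thread in standalone-client
      # mode, or in the evaluator process in independent-worker mode.
      pass

    cluster_spec = {
        'chief': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
        'evaluator': ['host3:2222'],  # presence triggers the single-machine eval path
    }

    dc.run_distribute_coordinator(
        worker_fn,
        tf.contrib.distribute.MirroredStrategy(),
        eval_fn=eval_fn,
        eval_strategy=None,  # None falls back to the training strategy
        mode=dc.CoordinatorMode.STANDALONE_CLIENT,
        cluster_spec=cluster_spec)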
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
new file mode 100644
index 0000000000..202e19c420
--- /dev/null
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -0,0 +1,264 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training utilities for Estimator to use Distribute Coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import six
+
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
+
+# pylint: disable=protected-access
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+PS = dc._TaskType.PS
+WORKER = dc._TaskType.WORKER
+
+# pylint: enable=protected-access
+
+
+def _count_ps(cluster_spec):
+ """Counts the number of parameter servers in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_ps` does not expect empty cluster_spec.')
+
+ return len(cluster_spec.as_dict().get(PS, []))
+
+
+def _count_worker(cluster_spec, chief_task_type):
+ """Counts the number of workers (including chief) in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_worker` does not expect empty cluster_spec.')
+
+ return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
+ cluster_spec.as_dict().get(chief_task_type, [])))
+
+
+def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
+ """Returns the global id of the given task type in a cluster."""
+ if not task_type:
+ return 0
+
+ # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
+ # and "ps". More details can be found at the documentation of
+ # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ task_type_ordered_list = []
+ if chief_task_type in cluster_spec.jobs:
+ task_type_ordered_list = [chief_task_type]
+ task_type_ordered_list.extend([
+ t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
+ ])
+ if PS in cluster_spec.jobs:
+ task_type_ordered_list.append(PS)
+
+ # Find the right global_id for the current task.
+ next_global_id = 0
+ for t in task_type_ordered_list:
+ if t == task_type:
+ return next_global_id + task_id
+ # `cluster_spec.job_tasks` returns all task addresses of type `t`.
+ next_global_id += len(cluster_spec.job_tasks(t))
+
+ # It is unexpected that it passes through all task_types in
+ # `task_type_ordered_list`.
+ raise RuntimeError('Internal Error: `task_type` ({}) is not in '
+ 'cluster_spec ({}).'.format(task_type, cluster_spec))
+
+
+def _init_run_config_from_worker_context(config, worker_context):
+ """Initializes run config from distribute coordinator's worker context."""
+
+ # pylint: disable=protected-access
+ config._service = None
+ config._cluster_spec = worker_context.cluster_spec
+ config._task_type = worker_context.task_type
+ config._task_id = worker_context.task_id
+ config._evaluation_master = worker_context.master_target
+ config._master = worker_context.master_target
+ config._is_chief = worker_context.is_chief
+
+ if config._cluster_spec:
+ # Distributed mode.
+ if config._task_type != EVALUATOR:
+
+ config._num_ps_replicas = _count_ps(config._cluster_spec)
+ config._num_worker_replicas = _count_worker(
+ config._cluster_spec, chief_task_type=CHIEF)
+ config._global_id_in_cluster = _get_global_id(
+ config._cluster_spec,
+ config._task_type,
+ config._task_id,
+ chief_task_type=CHIEF)
+ else:
+ # Evaluator task should not be aware of the other tasks.
+ config._cluster_spec = server_lib.ClusterSpec({})
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 0
+ config._global_id_in_cluster = None # undefined
+ else:
+ # Local mode.
+ config._global_id_in_cluster = 0
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 1
+
+
+def init_run_config(config, tf_config):
+ """Initializes RunConfig for distribution strategies."""
+ # pylint: disable=protected-access
+ if (config._experimental_distribute and
+ config._experimental_distribute.train_distribute):
+ if config._train_distribute:
+ raise ValueError('Either `train_distribute` or '
+ '`experimental_distribute.train_distribute` can be set.')
+ config._train_distribute = config._experimental_distribute.train_distribute
+
+ if (config._experimental_distribute and
+ config._experimental_distribute.eval_distribute):
+ if config._eval_distribute:
+ raise ValueError('Either `eval_distribute` or '
+ '`experimental_distribute.eval_distribute` can be set.')
+ config._eval_distribute = config._experimental_distribute.eval_distribute
+
+ cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
+ config._init_distributed_setting_from_environment_var({})
+
+ # Use distribute coordinator with STANDALONE_CLIENT mode if
+ # `experimental_distribute.remote_cluster` is set.
+ if (config._train_distribute and config._experimental_distribute and
+ config._experimental_distribute.remote_cluster):
+ if tf_config:
+ raise ValueError('Cannot set both TF_CONFIG environment variable and '
+ '`experimental_distribute.remote_cluster`')
+ config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
+ config._cluster_spec = config._experimental_distribute.remote_cluster
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'STANDALONE_CLIENT mode')
+ return
+
+ # Don't use the distribute coordinator for local training, when the cluster
+ # has a MASTER job, or when `train_distribute` is not specified.
+ if (not tf_config or 'master' in cluster_spec.jobs or
+ not config._train_distribute):
+ config._distribute_coordinator_mode = None
+ config._init_distributed_setting_from_environment_var(tf_config)
+ config._maybe_overwrite_session_config_for_distributed_training()
+ logging.info('Not using Distribute Coordinator.')
+ return
+
+ # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
+ assert tf_config
+
+ # Set the cluster_spec only since the distributed setting will come from
+ # distribute coordinator.
+ config._cluster_spec = cluster_spec
+ config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'INDEPENDENT_WORKER mode')
+
+
+def should_run_distribute_coordinator(config):
+ """Checks the config to see whether to run distribute coordinator."""
+ # pylint: disable=protected-access
+ if (not hasattr(config, '_distribute_coordinator_mode') or
+ config._distribute_coordinator_mode is None):
+ return False
+ if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
+ config._distribute_coordinator_mode not in [
+ dc.CoordinatorMode.STANDALONE_CLIENT,
+ dc.CoordinatorMode.INDEPENDENT_WORKER
+ ]):
+ logging.warning('Unexpected distribute_coordinator_mode: %r',
+ config._distribute_coordinator_mode)
+ return False
+ if not config.cluster_spec:
+ logging.warning('Running `train_and_evaluate` locally, ignoring '
+ '`experimental_distribute_coordinator_mode`.')
+ return False
+ return True
+
+
+def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
+ """Run distribute coordinator for Estimator's `train_and_evaluate`.
+
+ Args:
+ estimator: An `Estimator` instance to train and evaluate.
+ train_spec: A `TrainSpec` instance to specify the training specification.
+ eval_spec: An `EvalSpec` instance to specify the evaluation and export
+ specification.
+ executor_cls: the evaluation executor class of Estimator.
+
+ Raises:
+ ValueError: if `distribute_coordinator_mode` is None in RunConfig.
+ """
+ run_config = estimator.config
+ if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
+ raise ValueError(
+ 'Distribute coordinator mode is not specified in `RunConfig`.')
+
+ def _worker_fn(strategy):
+ """Function for worker task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._train_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._train_distribution = strategy
+ # pylint: enable=protected-access
+
+ local_estimator.train(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks))
+
+ def _eval_fn(strategy):
+ """Function for evaluator task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._eval_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._eval_distribution = strategy
+
+ executor = executor_cls(local_estimator, train_spec, eval_spec)
+ executor._start_continuous_evaluation()
+ # pylint: enable=protected-access
+
+ # pylint: disable=protected-access
+ if (run_config._distribute_coordinator_mode ==
+ dc.CoordinatorMode.STANDALONE_CLIENT):
+ cluster_spec = run_config.cluster_spec
+ assert cluster_spec
+ else:
+ # The cluster_spec comes from TF_CONFIG environment variable if it is
+ # INDEPENDENT_WORKER mode.
+ cluster_spec = None
+
+ dc.run_distribute_coordinator(
+ _worker_fn,
+ run_config.train_distribute,
+ _eval_fn,
+ run_config.eval_distribute,
+ mode=run_config._distribute_coordinator_mode,
+ cluster_spec=cluster_spec,
+ session_config=run_config.session_config)
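The ordering that `_get_global_id` implements (chief first, then the remaining non-PS task types in alphabetical order, then ps) can be spot-checked with a small cluster; a minimal sketch, assuming the new module is importable as tensorflow.python.distribute.estimator_training:

    from tensorflow.python.distribute import estimator_training
    from tensorflow.python.training import server_lib

    cluster = server_lib.ClusterSpec({
        'chief': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
        'evaluator': ['host3:2222'],
        'ps': ['host4:2222'],
    })

    # Ordered task types for this cluster: chief, evaluator, worker, ps.
    assert estimator_training._get_global_id(cluster, 'chief', 0, chief_task_type='chief') == 0
    assert estimator_training._get_global_id(cluster, 'evaluator', 0, chief_task_type='chief') == 1
    assert estimator_training._get_global_id(cluster, 'worker', 1, chief_task_type='chief') == 3
    assert estimator_training._get_global_id(cluster, 'ps', 0, chief_task_type='chief') == 4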