author    Yuefeng Zhou <yuefengz@google.com>                2018-08-30 18:20:57 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 18:25:46 -0700
commit    ab96371eaea4cc5b2f9c431eec455a1cf4be7c1c (patch)
tree      131229c608f5413844a3a31ba6a7bf5f1a581a32 /tensorflow/python/distribute
parent    970ec898689b7957eb45edfa1e55b727bec7976e (diff)
Set session configs for ParameterServerStrategy and CollectiveAllReduceStrategy in their configure methods.
Configure the session config in DistributeCoordinator. Allow the eval strategy to be None in DistributeCoordinator. Add more logging.

PiperOrigin-RevId: 211017587
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/BUILD                            |   4
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py       | 140
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  | 101
3 files changed, 217 insertions, 28 deletions
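
For orientation before the diff, a hedged sketch of the call pattern this change targets; `my_strategy` is a placeholder for an already-constructed ParameterServerStrategy or CollectiveAllReduceStrategy instance, and the cluster addresses are illustrative:

from tensorflow.python.distribute import distribute_coordinator

def worker_fn(strategy):
  # Build and run the per-worker training step under the worker context.
  pass

distribute_coordinator.run_distribute_coordinator(
    worker_fn,
    my_strategy,          # placeholder: its configure() fills in the session config
    eval_fn=None,         # worker_fn is reused for the "evaluator" task
    eval_strategy=None,   # now allowed: evaluation runs without a distribution strategy
    mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER,
    cluster_spec={"worker": ["w0:2222"], "ps": ["ps0:2222"]},
    task_type="worker",
    task_id=0)
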
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index a081c30781..bdc869c643 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -34,7 +34,11 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":distribute_coordinator_context",
+ ":multi_worker_util",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
"//tensorflow/python:training",
],
)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 46cdd64a6e..d9f78150b9 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -24,9 +24,10 @@ import os
import threading
import time
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -238,19 +239,26 @@ class _WorkerContext(object):
Returns:
a descendant of SessionCreator.
"""
- # TODO(yuefengz): merge session config.
- if self._strategy.should_init:
+ if config:
+ session_config = copy.deepcopy(config)
+ session_config.MergeFrom(self._session_config)
+ else:
+ session_config = self._session_config
+
+ if not self._strategy or self._strategy.should_init:
+ logging.info("Creating chief session creator with config: %r", config)
return monitored_session.ChiefSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
checkpoint_dir=checkpoint_dir,
checkpoint_filename_with_path=checkpoint_filename_with_path)
else:
+ logging.info("Creating worker session creator with config: %r", config)
return monitored_session.WorkerSessionCreator(
scaffold,
master=self.master_target,
- config=config or self._session_config,
+ config=session_config,
max_wait_secs=max_wait_secs)
@property
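
The merge added above relies on protobuf MergeFrom semantics: scalar fields that are set in the strategy's `self._session_config` override the caller's `config`, and repeated fields such as `device_filters` are concatenated. A minimal standalone sketch with illustrative values:

import copy

from tensorflow.core.protobuf import config_pb2

user_config = config_pb2.ConfigProto(allow_soft_placement=True)
strategy_config = config_pb2.ConfigProto(intra_op_parallelism_threads=2)
strategy_config.device_filters.append("/job:worker/task:0")

merged = copy.deepcopy(user_config)     # mirrors copy.deepcopy(config)
merged.MergeFrom(strategy_config)       # mirrors session_config.MergeFrom(...)

assert merged.allow_soft_placement                 # kept from the user config
assert merged.intra_op_parallelism_threads == 2    # taken from the strategy
assert list(merged.device_filters) == ["/job:worker/task:0"]
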
@@ -313,12 +321,17 @@ def _run_single_worker(worker_fn,
rpc_layer="",
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
+ session_config = copy.deepcopy(session_config)
strategy = copy.deepcopy(strategy)
# If there is an EVALUATOR task, we run single-machine eval on that task.
if task_type == _TaskType.EVALUATOR:
- strategy.configure(session_config)
+ # It is possible not to have a strategy object for the EVALUATOR task.
+ if strategy:
+ strategy.configure(session_config)
else:
+ assert strategy
strategy.configure(session_config, cluster_spec, task_type, task_id)
+
context = _WorkerContext(
strategy,
cluster_spec,
@@ -331,6 +344,25 @@ def _run_single_worker(worker_fn,
worker_fn(strategy)
+def _split_cluster_for_evaluator(cluster_spec, task_type):
+ """Split the cluster for evaluator since it needn't talk to other tasks."""
+ # Splitting the cluster is important to prevent the evaluator from talking to
+ # other tasks in the cluster. Since we allow evaluator not to use
+ # distribution strategies and as a result ops in the evalauator task may have
+ # unspecified devices. Those ops may end up on other tasks if we don't split
+ # the cluster.
+ new_cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_spec).as_dict()
+ if task_type == _TaskType.EVALUATOR:
+ assert _TaskType.EVALUATOR in new_cluster_spec
+ new_cluster_spec = {
+ _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
+ }
+ else:
+ new_cluster_spec.pop(_TaskType.EVALUATOR, None)
+ return multi_worker_util.normalize_cluster_spec(new_cluster_spec)
+
+
def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
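
As a plain-dict illustration of what `_split_cluster_for_evaluator` produces (addresses are made up; the real function round-trips through `multi_worker_util.normalize_cluster_spec`):

cluster = {
    "worker": ["w0:2222", "w1:2222"],
    "ps": ["ps0:2222"],
    "evaluator": ["e0:2222"],
}

def split_for_evaluator(cluster_dict, task_type):
  # Mirrors the logic above, minus ClusterSpec normalization.
  if task_type == "evaluator":
    return {"evaluator": cluster_dict["evaluator"]}
  rest = dict(cluster_dict)
  rest.pop("evaluator", None)
  return rest

assert split_for_evaluator(cluster, "evaluator") == {"evaluator": ["e0:2222"]}
assert split_for_evaluator(cluster, "worker") == {
    "worker": ["w0:2222", "w1:2222"], "ps": ["ps0:2222"]}
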
@@ -338,16 +370,19 @@ def _run_std_server(cluster_spec=None,
rpc_layer=None,
environment=None):
"""Runs a standard server."""
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
class _FakeServer(object):
"""A fake server that runs a master session."""
def start(self):
- assert cluster_spec
- target = cluster_spec.task_address(task_type, task_id)
- if rpc_layer:
- target = rpc_layer + "://" + target
# A tensorflow server starts when a remote session is created.
+ logging.info(
+ "Creating a remote session to start a TensorFlow server, "
+ "target = %r, session_config=%r", target, session_config)
session.Session(target=target, config=session_config)
def join(self):
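
The server target assembled above comes straight from the cluster spec; a small sketch with illustrative addresses (the rpc layer defaults to "grpc" elsewhere in this module):

from tensorflow.python.training import server_lib

cluster = server_lib.ClusterSpec(
    {"worker": ["localhost:2222"], "ps": ["localhost:2223"]})
target = cluster.task_address("worker", 0)   # 'localhost:2222'
target = "grpc" + "://" + target             # 'grpc://localhost:2222'
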
@@ -359,6 +394,13 @@ def _run_std_server(cluster_spec=None,
server.start()
return server
else:
+ if session_config:
+ logging.info(
+ "Starting standard TensorFlow server, target = %r, session_config= "
+ "%r", target, session_config)
+ else:
+ logging.info("Starting standard TensorFlow server, target = %r", target)
+ cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
server = server_lib.Server(
cluster_spec,
job_name=task_type,
@@ -376,7 +418,7 @@ def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -432,6 +474,33 @@ def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
if eval_thread:
eval_thread.join()
+
+def _configure_session_config_for_std_servers(
+ strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
+ # pylint: disable=g-doc-args
+ """Call strategy's `configure` to mutate the session_config.
+
+ The session_config is currently needed as the default config for a TensorFlow
+ server. In the future, we should be able to remove this method and only pass
+ the session config to a client session.
+ """
+ if task_type == _TaskType.EVALUATOR:
+ if eval_strategy:
+ eval_strategy.configure(session_config=session_config)
+ else:
+ # The strategy may be shared in standalone client mode.
+ strategy = copy.deepcopy(strategy)
+ strategy.configure(
+ session_config=session_config,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ # Remove the device filters specific to the strategy, so that the
+ # TensorFlow server brought up with one strategy can be used by other
+ # strategies. The device filters can be set in the client side as well.
+ del session_config.device_filters[:]
+
+
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
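
Clearing the strategy-specific device filters uses repeated-field slice deletion on the proto; a minimal sketch:

from tensorflow.core.protobuf import config_pb2

config = config_pb2.ConfigProto()
config.device_filters.extend(["/job:worker/task:0", "/job:ps"])
del config.device_filters[:]                 # clears the repeated field in place
assert list(config.device_filters) == []
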
@@ -533,8 +602,10 @@ def run_distribute_coordinator(worker_fn,
strategy: a DistributionStrategy object specifying whether it should
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
- `cluster_spc`, `task_type` and `task_id`.
- eval_fn: optional function for "evaluator" task.
+ `cluster_spec`, `task_type` and `task_id`.
+ eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
+ in but an "evaluator" task is found in the `cluster_spec`, the `worker_fn`
+ will be used for this task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
@@ -558,17 +629,17 @@ def run_distribute_coordinator(worker_fn,
task_id = int(task_env.get("index", task_id))
if cluster_spec:
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
# TODO(yuefengz): validate cluster_spec.
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
environment = tf_config.get("environment", None)
+ # Setting the session config is necessary for some strategies such as
+ # CollectiveAllReduceStrategy.
+ session_config = session_config or config_pb2.ConfigProto(
+ allow_soft_placement=True)
+
if cluster_spec:
logging.info(
"Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
@@ -581,11 +652,18 @@ def run_distribute_coordinator(worker_fn,
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
- _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+ _run_single_worker(eval_fn, eval_strategy, None, None, None,
session_config, rpc_layer)
+ else:
+ logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
# The client must know the cluster but servers in the cluster don't have to
# know the client.
@@ -598,10 +676,14 @@ def run_distribute_coordinator(worker_fn,
cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
server.join()
@@ -609,14 +691,24 @@ def run_distribute_coordinator(worker_fn,
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
+ if not eval_fn:
+ logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
+ "used if an \"evaluator\" task exists in the cluster.")
eval_fn = eval_fn or worker_fn
- eval_strategy = eval_strategy or strategy
-
- # Every one starts a standard server.
+ if not eval_strategy:
+ logging.warning("`eval_strategy` is not passed in. No distribution "
+ "strategy will be used for evaluation.")
+
+ # Everyone starts a standard server; the session config is obtained from the
+ # strategy's `configure` method.
+ _configure_session_config_for_std_servers(strategy, eval_strategy,
+ session_config, cluster_spec,
+ task_type, task_id)
server = _run_std_server(
cluster_spec=cluster_spec,
task_type=task_type,
task_id=task_id,
+ session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 5dd57fa134..ac5dd569ed 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -96,11 +96,10 @@ class MockStrategy(object):
return self._between_graph
def configure(self,
- session_options=None,
+ session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
- del session_options, cluster_spec, task_type
if self._should_init is None:
if task_id == 0:
self._should_init = True
@@ -117,6 +116,17 @@ class MockStrategy(object):
else:
self._should_save_summary = False
+ if session_config:
+ if (cluster_spec and task_type and task_id is not None and
+ self._between_graph):
+ session_config.intra_op_parallelism_threads += 1
+ if task_type in ["chief", "worker"]:
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (task_type, task_id), "/job:ps"])
+ else:
+ session_config.inter_op_parallelism_threads += 1
+ session_config.device_filters.append("/job:somejob")
+
@property
def should_init(self):
return self._should_init
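
The thread-count assertions in the tests below follow from how often the mock's `configure` runs (a reading of the flow in this patch, not documented behavior): the std-server setup configures the shared config once, and `_run_single_worker` then configures a deep copy for the session creator. Roughly:

import copy

from tensorflow.core.protobuf import config_pb2

config = config_pb2.ConfigProto()
config.intra_op_parallelism_threads += 1        # std-server setup: the mocked
                                                # _run_std_server observes 1
worker_copy = copy.deepcopy(config)             # _run_single_worker works on a copy
worker_copy.intra_op_parallelism_threads += 1   # the session-creator path observes 2
assert (config.intra_op_parallelism_threads,
        worker_copy.intra_op_parallelism_threads) == (1, 2)
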
@@ -748,7 +758,7 @@ class DistributeCoordinatorTestInpendentWorkerMode(
def _thread_fn(cluster_spec):
distribute_coordinator.run_distribute_coordinator(
None,
- None,
+ MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
@@ -785,7 +795,7 @@ class DistributeCoordinatorTestInpendentWorkerMode(
distribute_coordinator, "_run_std_server", _run_mock_server):
distribute_coordinator.run_distribute_coordinator(
None,
- None,
+ MockStrategy(between_graph=True),
mode=INDEPENDENT_WORKER,
cluster_spec=cluster_spec,
task_type="ps",
@@ -793,6 +803,89 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+class StrategyConfigureTest(test.TestCase):
+
+ def setUp(self):
+ self._device_filters = []
+ self._intra_op_parallelism_threads = None
+ self._inter_op_parallelism_threads = None
+ super(StrategyConfigureTest, self).setUp()
+
+ def _dump_device_filters(self, *args, **kwargs):
+ session_config = kwargs.get("session_config", None)
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def _worker_fn(self, strategy):
+ worker_context = distribute_coordinator_context.get_current_worker_context()
+ session_config = worker_context._session_config
+ self._device_filters.extend(session_config.device_filters)
+ self._intra_op_parallelism_threads = (
+ session_config.intra_op_parallelism_threads)
+ self._inter_op_parallelism_threads = (
+ session_config.inter_op_parallelism_threads)
+ return MockServer()
+
+ def test_session_config_in_std_server(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server",
+ self._dump_device_filters):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._intra_op_parallelism_threads, 1)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_session_config_in_session_creator(self):
+ cluster_spec = {"worker": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ self._worker_fn,
+ MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="worker",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:worker/task:0", "/job:ps"])
+ self.assertEqual(self._intra_op_parallelism_threads, 2)
+ self.assertEqual(self._inter_op_parallelism_threads, 0)
+
+ def test_eval_strategy_configure(self):
+ cluster_spec = {"evaluator": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec}
+
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ distribute_coordinator.run_distribute_coordinator(
+ lambda _: None,
+ MockStrategy(between_graph=False),
+ eval_fn=self._worker_fn,
+ eval_strategy=MockStrategy(between_graph=True),
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="evaluator",
+ task_id=0)
+ self.assertEqual(self._device_filters, ["/job:somejob"])
+ self.assertEqual(self._intra_op_parallelism_threads, 0)
+ self.assertEqual(self._inter_op_parallelism_threads, 2)
+
+
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminate std server threads.
with test.mock.patch.object(sys, "exit", os._exit):