author     Yuefeng Zhou <yuefengz@google.com>    2018-08-24 20:49:01 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-08-24 20:52:49 -0700
commit     04ffe2f34957f02d5a2aa4ead1c75233dd1cb1b7 (patch)
tree       b2e21643282628f005d758e617d75e6cd3af071c
parent     ca94990804cf5326c0f6f46d75c96e0f0e240366 (diff)
Add environment and rpc_layer to the TF_CONFIG environment variable in distribute coordinator.
PiperOrigin-RevId: 210197404
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py       | 65
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py  | 64
2 files changed, 117 insertions(+), 12 deletions(-)
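
For context on the commit message: the coordinator reads TF_CONFIG as a JSON string from the process environment. A minimal sketch of a TF_CONFIG value carrying the two new keys; the host addresses and the "task" entry are illustrative and not taken from this change:

    import json
    import os

    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {
            "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
            "ps": ["ps0.example.com:2222"],
        },
        "task": {"type": "worker", "index": 0},
        "rpc_layer": "grpc",        # new in this change: read by run_distribute_coordinator
        "environment": "google",    # new in this change: selects the in-process server path
    })
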
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 9cf0b3b7a6..46cdd64a6e 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -22,9 +22,12 @@ import copy
import json
import os
import threading
+import time
from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -332,16 +335,38 @@ def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
"""Runs a standard server."""
- server = server_lib.Server(
- cluster_spec,
- job_name=task_type,
- task_index=task_id,
- config=session_config,
- protocol=rpc_layer)
- server.start()
- return server
+
+ class _FakeServer(object):
+ """A fake server that runs a master session."""
+
+ def start(self):
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
+ # A tensorflow server starts when a remote session is created.
+ session.Session(target=target, config=session_config)
+
+ def join(self):
+ while True:
+ time.sleep(5)
+
+ if environment == "google":
+ server = _FakeServer()
+ server.start()
+ return server
+ else:
+ server = server_lib.Server(
+ cluster_spec,
+ job_name=task_type,
+ task_index=task_id,
+ config=session_config,
+ protocol=rpc_layer)
+ server.start()
+ return server
def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
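
The hunk above adds a second code path for the "google" environment: instead of launching a server_lib.Server, the coordinator builds the task's own target and opens a session against it, relying on the behavior noted in the added comment that creating a remote session starts the in-process server. The rpc_layer only affects how the target string is built; a small sketch with an illustrative single-worker cluster:

    from tensorflow.python.training import server_lib

    cluster_spec = server_lib.ClusterSpec({"worker": ["localhost:2222"]})
    target = cluster_spec.task_address("worker", 0)   # "localhost:2222"
    rpc_layer = "grpc"
    if rpc_layer:
        target = rpc_layer + "://" + target           # "grpc://localhost:2222"
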
@@ -541,8 +566,18 @@ def run_distribute_coordinator(worker_fn,
"`tf.train.ClusterDef` object")
# TODO(yuefengz): validate cluster_spec.
+ rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+ environment = tf_config.get("environment", None)
+
+ if cluster_spec:
+ logging.info(
+ "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+ "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+ cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
if not cluster_spec:
# `mode` is ignored in the local case.
+ logging.info("Running local Distribute Coordinator.")
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
if eval_fn:
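
The new lookups above give TF_CONFIG precedence: an "rpc_layer" key in TF_CONFIG overrides the rpc_layer argument, and "environment" defaults to None when absent. A sketch of those semantics; the module's actual TF_CONFIG parsing happens earlier in the function and is not shown in this hunk:

    import json
    import os

    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))

    rpc_layer = "grpc"                                  # value passed in by the caller
    rpc_layer = tf_config.get("rpc_layer", rpc_layer)   # TF_CONFIG wins when the key is set
    environment = tf_config.get("environment", None)    # e.g. "google"; None when absent
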
@@ -564,7 +599,11 @@ def run_distribute_coordinator(worker_fn,
else:
# If not a client job, run the standard server.
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ rpc_layer=rpc_layer,
+ environment=environment)
server.join()
else:
if mode != CoordinatorMode.INDEPENDENT_WORKER:
@@ -575,7 +614,11 @@ def run_distribute_coordinator(worker_fn,
# Every one starts a standard server.
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ rpc_layer=rpc_layer,
+ environment=environment)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
if strategy.between_graph:
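
With the two hunks above, a non-client task now forwards rpc_layer and environment to the standard server it brings up. A hedged sketch of driving the coordinator directly in INDEPENDENT_WORKER mode; worker_fn and strategy are caller-supplied placeholders, the addresses are illustrative, and the keyword names follow this file's signature:

    from tensorflow.python.distribute import distribute_coordinator

    # worker_fn and strategy are assumed to be defined by the caller.
    distribute_coordinator.run_distribute_coordinator(
        worker_fn,
        strategy,
        mode=distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER,
        cluster_spec={"worker": ["localhost:2222"], "ps": ["localhost:2223"]},
        task_type="worker",
        task_id=0,
        rpc_layer="grpc")
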
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 97c6bdd15a..5dd57fa134 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import contextlib
import copy
+import json
import os
import sys
+import time
import threading
import six
@@ -59,6 +61,8 @@ INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
NUM_WORKERS = 3
NUM_PS = 2
+original_sys_exit = sys.exit
+
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
@@ -369,7 +373,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec=None,
task_type=None,
task_id=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
task_type = str(task_type)
task_id = task_id or 0
with self._lock:
@@ -730,6 +735,63 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
+ def testRunStdServerInGoogleEnvironment(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+ joined = [False]
+
+ def _fake_sleep(_):
+ joined[0] = True
+ original_sys_exit(0)
+
+ def _thread_fn(cluster_spec):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ None,
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ time, "sleep", _fake_sleep):
+ t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+ t.start()
+ t.join()
+ self.assertTrue(joined[0])
+
+ def testRpcLayerEnvironmentVariable(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+ rpc_layer_from_coordinator = [None]
+
+ def _run_mock_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None,
+ environment=None):
+ del cluster_spec, task_type, task_id, session_config, environment
+ rpc_layer_from_coordinator[0] = rpc_layer
+ return MockServer()
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server", _run_mock_server):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ None,
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+ self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminate std server threads.
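
The second new test stubs _run_std_server with a function that records the rpc_layer it received and returns a MockServer, a helper defined earlier in this test file and not shown in the diff. A minimal stand-in under that assumption (the real helper may differ):

    class MockServer(object):
      """Stand-in for the test file's MockServer helper (assumed shape)."""

      def __init__(self):
        self._started = False
        self._joined = False

      def start(self):
        # The tests only need an object whose start()/join() calls can be observed.
        self._started = True

      def join(self):
        self._joined = True

      @property
      def started(self):
        return self._started

      @property
      def joined(self):
        return self._joined
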