author     Yuefeng Zhou <yuefengz@google.com>	2018-09-19 10:43:53 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>	2018-09-19 10:47:35 -0700
commit     428f7037bef6dbfdd01a4283a6c76221d381ef7e (patch)
tree       5feb498bd4ee6ba4d2438de2a11b446763dd47d2 /tensorflow/contrib/distribute
parent     5d5bc6d2b592374d7862cdebbc53e07b47e29c95 (diff)
Fix estimator_training test flakiness.
PiperOrigin-RevId: 213653403
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD                       |   3
-rw-r--r--  tensorflow/contrib/distribute/python/estimator_training_test.py  | 248
-rw-r--r--  tensorflow/contrib/distribute/python/multi_worker_test_base.py   |  53
3 files changed, 110 insertions(+), 194 deletions(-)
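
For context, and not part of the patch itself: the flakiness fix replaces the event-based task synchronization with a thread-local TF_CONFIG mock (the MockOsEnv class in the diff below) plus a barrier on std-server startup, and it deduplicates picked ports in multi_worker_test_base. The following minimal, runnable sketch shows only the thread-local-dict pattern; the name ThreadLocalEnv and the demo task function are illustrative and do not appear in the patch.

import json
import threading


class ThreadLocalEnv(dict):
  """Plain dict, except that the TF_CONFIG key is stored per thread."""

  def __init__(self, *args):
    super(ThreadLocalEnv, self).__init__(*args)
    self._local = threading.local()

  def __setitem__(self, key, val):
    if key == "TF_CONFIG":
      self._local.tf_config = val
    else:
      dict.__setitem__(self, key, val)

  def get(self, key, default=None):
    if key == "TF_CONFIG":
      return getattr(self._local, "tf_config", default)
    return dict.get(self, key, default)


def task(env, task_type, task_id):
  # Each simulated task writes its own TF_CONFIG and reads back only the
  # value written by its own thread.
  env["TF_CONFIG"] = json.dumps({"task": {"type": task_type, "index": task_id}})
  print(threading.current_thread().name, env.get("TF_CONFIG"))


if __name__ == "__main__":
  env = ThreadLocalEnv()
  threads = [threading.Thread(target=task, args=(env, "worker", i))
             for i in range(3)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
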
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index f72b827e04..ebea512c04 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -472,11 +472,8 @@ cuda_py_test(
"//tensorflow/python:summary",
],
tags = [
- "manual",
"multi_and_single_gpu",
"no_pip",
- "nogpu",
- "notap",
],
)
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 5348512016..157618f72f 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -26,21 +26,12 @@ import tempfile
import threading
from absl.testing import parameterized
import numpy as np
-import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
-
-# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
@@ -57,7 +48,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
@@ -73,130 +63,38 @@ EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
-original_run_distribute_coordinator = dc.run_distribute_coordinator
-
-
-# TODO(yuefengz): merge this method back to test_util.
-def _create_local_cluster(num_workers,
- num_ps,
- has_eval=False,
- protocol="grpc",
- worker_config=None,
- ps_config=None):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
-
- cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
- }
- if has_eval:
- cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
-
- cs = server_lib.ClusterSpec(cluster_dict)
-
- workers = [
- server_lib.Server(
- cs,
- job_name="worker",
- protocol=protocol,
- task_index=ix,
- config=worker_config,
- start=True) for ix in range(num_workers)
- ]
- ps_servers = [
- server_lib.Server(
- cs,
- job_name="ps",
- protocol=protocol,
- task_index=ix,
- config=ps_config,
- start=True) for ix in range(num_ps)
- ]
- if has_eval:
- evals = [
- server_lib.Server(
- cs,
- job_name="evaluator",
- protocol=protocol,
- task_index=0,
- config=worker_config,
- start=True)
- ]
- else:
- evals = []
-
- return workers, ps_servers, evals
-
-
-def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
- """Create an in-process cluster that consists of only standard server."""
- # Leave some memory for cuda runtime.
- if has_eval:
- gpu_mem_frac = 0.7 / (num_workers + 1)
- else:
- gpu_mem_frac = 0.7 / num_workers
-
- worker_config = config_pb2.ConfigProto()
- worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
- # Enable collective ops which has no impact on non-collective ops.
- # TODO(yuefengz, tucker): removing this after we move the initialization of
- # collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- "/job:worker/replica:0/task:0")
-
- ps_config = config_pb2.ConfigProto()
- ps_config.device_count["GPU"] = 0
-
- return _create_local_cluster(
- num_workers,
- num_ps=num_ps,
- has_eval=has_eval,
- worker_config=worker_config,
- ps_config=ps_config,
- protocol="grpc")
-
-
-def _create_cluster_spec(has_chief=False,
- num_workers=1,
- num_ps=0,
- has_eval=False):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
- cluster_spec = {}
- if has_chief:
- cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
- if num_workers:
- cluster_spec[WORKER] = [
- "localhost:%s" % portpicker.pick_unused_port()
- for _ in range(num_workers)
- ]
- if num_ps:
- cluster_spec[PS] = [
- "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
- ]
- if has_eval:
- cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
- return cluster_spec
+original_run_std_server = dc._run_std_server
-def _bytes_to_str(maybe_bytes):
- if isinstance(maybe_bytes, six.string_types):
- return maybe_bytes
- else:
- return str(maybe_bytes, "utf-8")
+class MockOsEnv(dict):
+
+ def __init__(self, *args):
+ self._thread_local = threading.local()
+ super(MockOsEnv, self).__init__(*args)
+
+ def get(self, key, default):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.get(self._thread_local.dict, key, default)
+ else:
+ return dict.get(self, key, default)
+ def __getitem__(self, key):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__getitem__(self._thread_local.dict, key)
+ else:
+ return dict.__getitem__(self, key)
-def _strip_protocol(target):
- # cluster_spec expects "host:port" strings.
- if "//" in target:
- return target.split("//")[1]
- else:
- return target
+ def __setitem__(self, key, val):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__setitem__(self._thread_local.dict, key, val)
+ else:
+ return dict.__setitem__(self, key, val)
class DistributeCoordinatorIntegrationTest(test.TestCase,
@@ -205,22 +103,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
- cls._cluster_spec = {
- "worker": [
- _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
- ],
- "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
- "evaluator": [
- _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
- ]
- }
def setUp(self):
self._model_dir = tempfile.mkdtemp()
- self._event = threading.Event()
+ self._mock_os_env = MockOsEnv()
+ self._mock_context = test.mock.patch.object(os, "environ",
+ self._mock_os_env)
super(DistributeCoordinatorIntegrationTest, self).setUp()
+ self._mock_context.__enter__()
+
+ def tearDown(self):
+ self._mock_context.__exit__(None, None, None)
+ super(DistributeCoordinatorIntegrationTest, self).tearDown()
def dataset_input_fn(self, x, y, batch_size, shuffle):
@@ -391,43 +287,17 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
- def _mock_run_distribute_coordinator(
- self,
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=dc.CoordinatorMode.STANDALONE_CLIENT,
- cluster_spec=None,
- session_config=None):
- # Calls the origial `run_distribute_coordinator` method but gets task config
- # from environment variables and then signals the caller.
- task_type = None
- task_id = None
- if not cluster_spec:
- cluster_spec = None
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- if not cluster_spec:
- cluster_spec = tf_config.get("cluster", {})
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", task_type)
- task_id = int(task_env.get("index", task_id))
- self._event.set()
- original_run_distribute_coordinator(
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=mode,
- cluster_spec=cluster_spec,
- task_type=task_type,
- task_id=task_id,
- session_config=session_config)
-
- def _task_thread(self, train_distribute, eval_distribute):
- with test.mock.patch.object(dc, "run_distribute_coordinator",
- self._mock_run_distribute_coordinator):
+ def _mock_run_std_server(self, *args, **kwargs):
+ ret = original_run_std_server(*args, **kwargs)
+ # Wait for all std servers to be brought up in order to reduce the chance of
+ # remote sessions taking local ports that have been assigned to std servers.
+ self._barrier.wait()
+ return ret
+
+ def _task_thread(self, train_distribute, eval_distribute, tf_config):
+ os.environ["TF_CONFIG"] = json.dumps(tf_config)
+ with test.mock.patch.object(dc, "_run_std_server",
+ self._mock_run_std_server):
self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
@@ -448,13 +318,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
"index": task_id
}
}
- self._event.clear()
t = threading.Thread(
- target=self._task_thread, args=(train_distribute, eval_distribute))
- with test.mock.patch.dict("os.environ",
- {"TF_CONFIG": json.dumps(tf_config)}):
- t.start()
- self._event.wait()
+ target=self._task_thread,
+ args=(train_distribute, eval_distribute, tf_config))
+ t.start()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
@@ -489,7 +356,11 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=2, has_eval=True)
+ # 3 workers, 2 ps and 1 evaluator.
+ self._barrier = dc._Barrier(6)
+
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
@@ -516,7 +387,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=0, has_eval=True)
+ # 3 workers and 1 evaluator.
+ self._barrier = dc._Barrier(4)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
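
Also not part of the patch: the barrier-based startup used by _mock_run_std_server above can be illustrated with plain threading primitives. In this sketch threading.Barrier (Python 3) stands in for dc._Barrier, and start_std_server is a placeholder rather than a TensorFlow API. Every task thread brings up its server and then waits until all six servers exist before continuing, so later port allocations cannot race with a server that has not started yet.

import threading

NUM_TASKS = 6  # 3 workers + 2 ps + 1 evaluator, matching the first test above.
barrier = threading.Barrier(NUM_TASKS)
results = []
results_lock = threading.Lock()


def start_std_server(task_type, task_id):
  # Placeholder for dc._run_std_server(); it only records the call here.
  return ("server", task_type, task_id)


def run_task(task_type, task_id):
  server = start_std_server(task_type, task_id)
  # Block until all NUM_TASKS servers are up, mirroring _mock_run_std_server.
  barrier.wait()
  with results_lock:
    results.append(server)


threads = ([threading.Thread(target=run_task, args=("worker", i))
            for i in range(3)] +
           [threading.Thread(target=run_task, args=("ps", i))
            for i in range(2)] +
           [threading.Thread(target=run_task, args=("evaluator", 0))])
for t in threads:
  t.start()
for t in threads:
  t.join()
assert len(results) == NUM_TASKS
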
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 18b4503eff..9f92ba7dde 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -36,9 +36,29 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+ """Returns an unused and unassigned local port."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ global ASSIGNED_PORTS
+ with lock:
+ while True:
+ port = portpicker.pick_unused_port()
+ if port > 10000 and port not in ASSIGNED_PORTS:
+ ASSIGNED_PORTS.add(port)
+ logging.info('Using local port %r', port)
+ return port
+
+
def _create_cluster(num_workers,
num_ps,
has_chief=False,
@@ -49,8 +69,8 @@ def _create_cluster(num_workers,
"""Creates and starts local servers and returns the cluster_spec dict."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ worker_ports = [pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
@@ -58,9 +78,9 @@ def _create_cluster(num_workers,
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
- cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
- cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
@@ -139,11 +159,36 @@ def create_in_process_cluster(num_workers,
num_workers,
num_ps=num_ps,
has_chief=has_chief,
+ has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
+def create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ """Create a cluster spec with tasks with unused local ports."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
+ if num_workers:
+ cluster_spec['worker'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec['ps'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
+ return cluster_spec
+
+
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""