aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/estimator/run_config.py38
-rw-r--r--tensorflow/python/estimator/run_config_test.py112
2 files changed, 150 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index b948ce96e0..3d60c63b68 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -25,6 +25,7 @@ import os
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
@@ -484,6 +485,43 @@ class RunConfig(object):
self._init_distributed_setting_from_environment_var(tf_config)
+ # Get session_config only for distributed mode (cluster_spec is present).
+ if not self._session_config and self._cluster_spec:
+ RunConfig._replace(
+ self,
+ allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
+ session_config=self._get_default_session_config())
+
+  def _get_default_session_config(self):
+    """Builds the default session config for this task, if it has one.
+
+    For the task types listed below the returned proto has
+    `allow_soft_placement=True`, graph rewrite options with
+    `meta_optimizer_iterations` set to `RewriterConfig.ONE`, and a
+    `device_filters` list picked by task type:
+
+    * master: `/job:ps` and `/job:master`.
+    * chief: `/job:ps` and `/job:chief`.
+    * worker: `/job:ps` and this worker's own `/job:worker/task:<task_id>`.
+    * ps: `/job:ps`, `/job:worker` and `/job:master`.
+
+    Returns:
+      A `config_pb2.ConfigProto`, or `None` when the task type is `EVALUATOR`
+      or anything other than the task types listed above.
+    """
+    # Ordered (task type, filter builder) pairs. The builders are lazy so that
+    # `self._task_id` is only formatted when the task really is a worker.
+    # NOTE(review): the ps entry does not list '/job:chief' -- confirm this is
+    # intended for clusters that use a chief instead of a master.
+    filter_builders = (
+        (TaskType.MASTER, lambda: ['/job:ps', '/job:master']),
+        (TaskType.CHIEF, lambda: ['/job:ps', '/job:chief']),
+        (TaskType.WORKER,
+         lambda: ['/job:ps', '/job:worker/task:%d' % self._task_id]),
+        (TaskType.PS, lambda: ['/job:ps', '/job:worker', '/job:master']),
+    )
+
+    for known_type, build_filters in filter_builders:
+      if self._task_type == known_type:
+        device_filters = build_filters()
+        break
+    else:
+      # `EVALUATOR`, or a task type outside the table: no default config.
+      return None
+
+    single_pass_rewriter = rewriter_config_pb2.RewriterConfig(
+        meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+    return config_pb2.ConfigProto(
+        allow_soft_placement=True,
+        graph_options=config_pb2.GraphOptions(
+            rewrite_options=single_pass_rewriter),
+        device_filters=device_filters)
+
+
def _init_distributed_setting_from_environment_var(self, tf_config):
"""Initialize distributed properties based on `tf_config`."""
diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py
index c8b12605e1..06df7cb9dd 100644
--- a/tensorflow/python/estimator/run_config_test.py
+++ b/tensorflow/python/estimator/run_config_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
@@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase):
expected_num_worker_replicas=1,
expected_num_ps_replicas=0)
self.assertEqual(0, run_config.global_id_in_cluster)
+ self.assertIsNone(run_config.session_config, None)
def test_session_master_for_local(self):
tf_config = {'session_master': '_my_master'}
@@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase):
_create_run_config_with_cluster_spec(tf_config)
+class RunConfigSessionConfigTest(test.TestCase):
+  """Checks the `session_config` a RunConfig exposes for each task type."""
+
+  def _tf_config(self, lead_job, task_type, task_index, extra_jobs=()):
+    """Returns a `tf_config` dict for the cluster shared by these tests.
+
+    Args:
+      lead_job: Job name that owns 'host0:0' (master or chief in these tests).
+      task_type: Value stored under the task's 'type' key.
+      task_index: Value stored under the task's 'index' key.
+      extra_jobs: Optional `(job_name, hosts)` pairs inserted into the cluster
+        between the ps and worker jobs.
+    """
+    cluster = {lead_job: ['host0:0']}
+    cluster[run_config_lib.TaskType.PS] = ['host1:1', 'host2:2']
+    for job_name, hosts in extra_jobs:
+      cluster[job_name] = hosts
+    cluster[run_config_lib.TaskType.WORKER] = [
+        'host3:3', 'host4:4', 'host5:5'
+    ]
+    return {
+        'cluster': cluster,
+        'task': {
+            'type': task_type,
+            'index': task_index
+        }
+    }
+
+  def _assert_equal_session_config(self, session_config,
+                                   expected_device_filters):
+    """Asserts `session_config` equals the expected default ConfigProto.
+
+    Args:
+      session_config: The config under test.
+      expected_device_filters: Device filter strings the config should carry.
+    """
+    expected_session_config = config_pb2.ConfigProto(
+        allow_soft_placement=True,
+        graph_options=config_pb2.GraphOptions(
+            rewrite_options=rewriter_config_pb2.RewriterConfig(
+                meta_optimizer_iterations=(
+                    rewriter_config_pb2.RewriterConfig.ONE))),
+        device_filters=expected_device_filters)
+    self.assertEqual(session_config, expected_session_config)
+
+  def test_master_session_config(self):
+    task_types = run_config_lib.TaskType
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(task_types.MASTER, task_types.MASTER, 0))
+    self._assert_equal_session_config(run_config.session_config,
+                                      ['/job:ps', '/job:master'])
+
+  def test_chief_session_config(self):
+    task_types = run_config_lib.TaskType
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(task_types.CHIEF, task_types.CHIEF, 0))
+    self._assert_equal_session_config(run_config.session_config,
+                                      ['/job:ps', '/job:chief'])
+
+  def test_worker_session_config(self):
+    task_types = run_config_lib.TaskType
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(task_types.MASTER, task_types.WORKER, 1))
+    self._assert_equal_session_config(run_config.session_config,
+                                      ['/job:ps', '/job:worker/task:1'])
+
+  def test_ps_session_config(self):
+    task_types = run_config_lib.TaskType
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(task_types.MASTER, task_types.PS, 1))
+    self._assert_equal_session_config(
+        run_config.session_config,
+        ['/job:ps', '/job:worker', '/job:master'])
+
+  def test_evaluator_session_config(self):
+    task_types = run_config_lib.TaskType
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(task_types.CHIEF, task_types.EVALUATOR, 0))
+    self.assertIsNone(run_config.session_config)
+
+  def test_other_type_session_config(self):
+    # A job name outside TaskType should get no default session config.
+    run_config = _create_run_config_with_cluster_spec(
+        self._tf_config(
+            run_config_lib.TaskType.MASTER,
+            'other_type',
+            0,
+            extra_jobs=[('other_type', ['host3:1', 'host4:2'])]))
+    self.assertIsNone(run_config.session_config)
+
+
if __name__ == '__main__':
test.main()