aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/run_config_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 17:50:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 17:53:35 -0700
commiteacb2c097afb3b688fbc6953c3c4436635286c78 (patch)
treeb258e27609b1ba62bdd9632640ab17b021221af9 /tensorflow/python/estimator/run_config_test.py
parenta350f66ed250c3dee43cc27b0778c3759f07e810 (diff)
Setting default device_filters in session_config for tf.Estimator's RunConfig
PiperOrigin-RevId: 201617644
Diffstat (limited to 'tensorflow/python/estimator/run_config_test.py')
-rw-r--r--tensorflow/python/estimator/run_config_test.py112
1 files changed, 112 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py
index c8b12605e1..06df7cb9dd 100644
--- a/tensorflow/python/estimator/run_config_test.py
+++ b/tensorflow/python/estimator/run_config_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
@@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase):
expected_num_worker_replicas=1,
expected_num_ps_replicas=0)
self.assertEqual(0, run_config.global_id_in_cluster)
+ self.assertIsNone(run_config.session_config, None)
def test_session_master_for_local(self):
tf_config = {'session_master': '_my_master'}
@@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase):
_create_run_config_with_cluster_spec(tf_config)
+class RunConfigSessionConfigTest(test.TestCase):
+
+ def _assert_equal_session_config(self, session_config,
+ expected_device_filters):
+
+ rewrite_opts = rewriter_config_pb2.RewriterConfig(
+ meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+ graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
+ expected_session_config = config_pb2.ConfigProto(
+ allow_soft_placement=True,
+ graph_options=graph_opts,
+ device_filters=expected_device_filters)
+ self.assertEqual(session_config, expected_session_config)
+
+ def test_master_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.MASTER,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:master'])
+
+ def test_chief_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:chief'])
+
+ def test_worker_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.WORKER,
+ 'index': 1
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:worker/task:1'])
+
+ def test_ps_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.PS,
+ 'index': 1
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:worker', '/job:master'])
+
+ def test_evaluator_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.EVALUATOR,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self.assertIsNone(run_config.session_config)
+
+ def test_other_type_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ 'other_type': ['host3:1', 'host4:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': 'other_type',
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self.assertIsNone(run_config.session_config)
+
+
if __name__ == '__main__':
test.main()