aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/run_config_test.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config_test.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py
index 14cef7cc43..6d39a9ad13 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py
@@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase):
self.assertNotEqual(expected_uid, new_config.uid())
self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
+ def test_uid_for_whitelist(self):
+ whitelist = ["model_dir"]
+ config = run_config_lib.RunConfig(
+ tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
+
+ expected_uid = config.uid(whitelist)
+ self.assertEqual(expected_uid, config.uid(whitelist))
+
+ new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
+ self.assertEqual(TEST_DIR, config.model_dir)
+ self.assertEqual(expected_uid, new_config.uid(whitelist))
+ self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir)
+
+ def test_uid_for_default_whitelist(self):
+ config = run_config_lib.RunConfig(
+ tf_random_seed=11,
+ save_summary_steps=12,
+ save_checkpoints_steps=13,
+ save_checkpoints_secs=14,
+ session_config=15,
+ keep_checkpoint_max=16,
+ keep_checkpoint_every_n_hours=17)
+ self.assertEqual(11, config.tf_random_seed)
+ self.assertEqual(12, config.save_summary_steps)
+ self.assertEqual(13, config.save_checkpoints_steps)
+ self.assertEqual(14, config.save_checkpoints_secs)
+ self.assertEqual(15, config.session_config)
+ self.assertEqual(16, config.keep_checkpoint_max)
+ self.assertEqual(17, config.keep_checkpoint_every_n_hours)
+
+ new_config = run_config_lib.RunConfig(
+ tf_random_seed=21,
+ save_summary_steps=22,
+ save_checkpoints_steps=23,
+ save_checkpoints_secs=24,
+ session_config=25,
+ keep_checkpoint_max=26,
+ keep_checkpoint_every_n_hours=27)
+ self.assertEqual(config.uid(), new_config.uid())
+ # model_dir is not on the default whitelist.
+ self.assertNotEqual(config.uid(whitelist=[]),
+ new_config.uid(whitelist=[]))
+ new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR)
+ self.assertNotEqual(config.uid(), new_config.uid())
+
def test_uid_for_deepcopy(self):
tf_config = {
"cluster": {