diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/estimators/run_config_test.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/run_config_test.py | 45 |
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py index 14cef7cc43..6d39a9ad13 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config_test.py @@ -257,6 +257,51 @@ class RunConfigTest(test.TestCase): self.assertNotEqual(expected_uid, new_config.uid()) self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) + def test_uid_for_whitelist(self): + whitelist = ["model_dir"] + config = run_config_lib.RunConfig( + tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR) + + expected_uid = config.uid(whitelist) + self.assertEqual(expected_uid, config.uid(whitelist)) + + new_config = config.replace(model_dir=ANOTHER_TEST_DIR) + self.assertEqual(TEST_DIR, config.model_dir) + self.assertEqual(expected_uid, new_config.uid(whitelist)) + self.assertEqual(ANOTHER_TEST_DIR, new_config.model_dir) + + def test_uid_for_default_whitelist(self): + config = run_config_lib.RunConfig( + tf_random_seed=11, + save_summary_steps=12, + save_checkpoints_steps=13, + save_checkpoints_secs=14, + session_config=15, + keep_checkpoint_max=16, + keep_checkpoint_every_n_hours=17) + self.assertEqual(11, config.tf_random_seed) + self.assertEqual(12, config.save_summary_steps) + self.assertEqual(13, config.save_checkpoints_steps) + self.assertEqual(14, config.save_checkpoints_secs) + self.assertEqual(15, config.session_config) + self.assertEqual(16, config.keep_checkpoint_max) + self.assertEqual(17, config.keep_checkpoint_every_n_hours) + + new_config = run_config_lib.RunConfig( + tf_random_seed=21, + save_summary_steps=22, + save_checkpoints_steps=23, + save_checkpoints_secs=24, + session_config=25, + keep_checkpoint_max=26, + keep_checkpoint_every_n_hours=27) + self.assertEqual(config.uid(), new_config.uid()) + # model_dir is not on the default whitelist. + self.assertNotEqual(config.uid(whitelist=[]), + new_config.uid(whitelist=[])) + new_config = new_config.replace(model_dir=ANOTHER_TEST_DIR) + self.assertNotEqual(config.uid(), new_config.uid()) + def test_uid_for_deepcopy(self): tf_config = { "cluster": { |