diff options
author | Shanqing Cai <cais@google.com> | 2018-04-30 10:00:24 -0400 |
---|---|---|
committer | Shanqing Cai <cais@google.com> | 2018-04-30 10:00:24 -0400 |
commit | 7132e3cdab43c7e2e176f9229ea4fa679098e1e5 (patch) | |
tree | 0b994288004176863c0e67ef646c715467886b66 /tensorflow/contrib/training | |
parent | 6f3cc9d368a17646f5838e36be3b1c25bf4534fe (diff) | |
parent | 914796d5e9bc7b0c619b53c7eb24cfe7d6c7fb9b (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam_test.py | 16 |
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 6c59b68053..f0418f04ba 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -502,6 +502,16 @@ class HParams(object): 'Must pass a list for multi-valued parameter: %s.' % name) setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + def parse(self, values): """Override hyperparameter values, parsing new values from a string. diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 96eff86d8d..11fd15b527 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -439,6 +439,22 @@ class HParamsTest(test.TestCase): self.assertEqual(123, hparams.get('unknown', 123)) self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3])) + def testDel(self): + hparams = hparam.HParams(aaa=1, b=2.0) + + with self.assertRaises(ValueError): + hparams.set_hparam('aaa', 'will fail') + + with self.assertRaises(ValueError): + hparams.add_hparam('aaa', 'will fail') + + hparams.del_hparam('aaa') + hparams.add_hparam('aaa', 'will work') + self.assertEqual('will work', hparams.get('aaa')) + + hparams.set_hparam('aaa', 'still works') + self.assertEqual('still works', hparams.get('aaa')) + if __name__ == '__main__': test.main() |