aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Sherry Moore <sherrym@google.com>2018-04-29 09:56:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-29 10:01:27 -0700
commit45529aaac3f5c1d290c285a4e86c434600ec2d92 (patch)
tree56e8ef2b1aacb1b5b442d55792758ee6f090b5a4 /tensorflow/contrib/training
parent2e1f3efcb34380df1441660d9759b44bb07cf1cd (diff)
Added del_hparam(), the counter part of add_hparam.
PiperOrigin-RevId: 194711291
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py10
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py16
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 6c59b68053..f0418f04ba 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -502,6 +502,16 @@ class HParams(object):
'Must pass a list for multi-valued parameter: %s.' % name)
setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+ def del_hparam(self, name):
+ """Removes the hyperparameter with key 'name'.
+
+ Args:
+ name: Name of the hyperparameter.
+ """
+ if hasattr(self, name):
+ delattr(self, name)
+ del self._hparam_types[name]
+
def parse(self, values):
"""Override hyperparameter values, parsing new values from a string.
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 96eff86d8d..11fd15b527 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -439,6 +439,22 @@ class HParamsTest(test.TestCase):
self.assertEqual(123, hparams.get('unknown', 123))
self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
+ def testDel(self):
+ hparams = hparam.HParams(aaa=1, b=2.0)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('aaa', 'will fail')
+
+ with self.assertRaises(ValueError):
+ hparams.add_hparam('aaa', 'will fail')
+
+ hparams.del_hparam('aaa')
+ hparams.add_hparam('aaa', 'will work')
+ self.assertEqual('will work', hparams.get('aaa'))
+
+ hparams.set_hparam('aaa', 'still works')
+ self.assertEqual('still works', hparams.get('aaa'))
+
if __name__ == '__main__':
test.main()