aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2018-04-30 10:00:24 -0400
committerGravatar Shanqing Cai <cais@google.com>2018-04-30 10:00:24 -0400
commit7132e3cdab43c7e2e176f9229ea4fa679098e1e5 (patch)
tree0b994288004176863c0e67ef646c715467886b66 /tensorflow/contrib/training
parent6f3cc9d368a17646f5838e36be3b1c25bf4534fe (diff)
parent914796d5e9bc7b0c619b53c7eb24cfe7d6c7fb9b (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py10
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py16
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 6c59b68053..f0418f04ba 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -502,6 +502,16 @@ class HParams(object):
'Must pass a list for multi-valued parameter: %s.' % name)
setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+ def del_hparam(self, name):
+ """Removes the hyperparameter with key 'name'.
+
+ Args:
+ name: Name of the hyperparameter.
+ """
+ if hasattr(self, name):
+ delattr(self, name)
+ del self._hparam_types[name]
+
def parse(self, values):
"""Override hyperparameter values, parsing new values from a string.
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 96eff86d8d..11fd15b527 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -439,6 +439,22 @@ class HParamsTest(test.TestCase):
self.assertEqual(123, hparams.get('unknown', 123))
self.assertEqual([1, 2, 3], hparams.get('unknown', [1, 2, 3]))
+ def testDel(self):
+ hparams = hparam.HParams(aaa=1, b=2.0)
+
+ with self.assertRaises(ValueError):
+ hparams.set_hparam('aaa', 'will fail')
+
+ with self.assertRaises(ValueError):
+ hparams.add_hparam('aaa', 'will fail')
+
+ hparams.del_hparam('aaa')
+ hparams.add_hparam('aaa', 'will work')
+ self.assertEqual('will work', hparams.get('aaa'))
+
+ hparams.set_hparam('aaa', 'still works')
+ self.assertEqual('still works', hparams.get('aaa'))
+
if __name__ == '__main__':
test.main()