diff options
author | Yifei Feng <yifeif@google.com> | 2018-05-15 10:58:43 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-05-15 10:58:43 -0700 |
commit | 41ed2381d9d958abf6c93eb8e49e88282a5099ae (patch) | |
tree | 03142a327497f65199f8c9ae86062de1ce1c1e71 /tensorflow/contrib/training | |
parent | de4a6e646be56ca59c78dd6f92f8f6bcc7196696 (diff) | |
parent | a7a3bb3df12c632b81bf1b23f8405f92a0c903c3 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam_test.py | 15 |
2 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index f0418f04ba..3beb7bfe30 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation # where <rhs> is either a single token or [] enclosed list of tokens. # For example: "var[1] = a" or "x = [1,2,3]" PARAM_RE = re.compile(r""" - (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x" + (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None \s*=\s* ((?P<val>[^,\[]*) # single value: "a" or None @@ -200,6 +200,13 @@ def parse_values(values, type_map): If a hyperparameter name in both an index assignment and scalar assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must be first explicit added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + The `value` in `name=value` must follows the syntax according to the type of the parameter: diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 11fd15b527..660c97f25e 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -118,6 +118,21 @@ class HParamsTest(test.TestCase): self.assertEqual('2.3"', hparams2.c_c) self.assertEqual('/a=b/c/d', hparams2.d) + def testWithPeriodInVariableName(self): + hparams = hparam.HParams() + hparams.add_hparam(name='a.b', value=0.0) + hparams.parse('a.b=1.0') + self.assertEqual(1.0, getattr(hparams, 'a.b')) + hparams.add_hparam(name='c.d', value=0.0) + with self.assertRaisesRegexp(ValueError, 'Could not parse'): + hparams.parse('c.d=abc') + hparams.add_hparam(name='e.f', value='') + hparams.parse('e.f=abc') + self.assertEqual('abc', getattr(hparams, 'e.f')) + hparams.add_hparam(name='d..', value=0.0) + hparams.parse('d..=10.0') + self.assertEqual(10.0, getattr(hparams, 'd..')) + def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') hparams.override_from_dict({'a': -2, 'c': 'identity'}) |