diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-11 11:04:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 11:10:59 -0700 |
commit | 1d6973d68b5d617e3a2dbf935643d0c0e4dcdac5 (patch) | |
tree | 57a4bdf181f42b8cd0bb4f4394bcf00740e4ab91 /tensorflow/contrib/training | |
parent | e1562e72c197ec830547a051ddfe0f720acb9f67 (diff) |
RELNOTES: This allows the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
NOTE: the use of '.' in variable names is now allowed, but it is not recommended.
PiperOrigin-RevId: 196278660
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam_test.py | 15 |
2 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index f0418f04ba..3beb7bfe30 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation # where <rhs> is either a single token or [] enclosed list of tokens. # For example: "var[1] = a" or "x = [1,2,3]" PARAM_RE = re.compile(r""" - (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x" + (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None \s*=\s* ((?P<val>[^,\[]*) # single value: "a" or None @@ -200,6 +200,13 @@ def parse_values(values, type_map): If a hyperparameter name in both an index assignment and scalar assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must be first explicit added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + The `value` in `name=value` must follows the syntax according to the type of the parameter: diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 11fd15b527..660c97f25e 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -118,6 +118,21 @@ class HParamsTest(test.TestCase): self.assertEqual('2.3"', hparams2.c_c) self.assertEqual('/a=b/c/d', hparams2.d) + def testWithPeriodInVariableName(self): + hparams = hparam.HParams() + hparams.add_hparam(name='a.b', value=0.0) + hparams.parse('a.b=1.0') + self.assertEqual(1.0, getattr(hparams, 'a.b')) + hparams.add_hparam(name='c.d', value=0.0) + with self.assertRaisesRegexp(ValueError, 'Could not parse'): + hparams.parse('c.d=abc') + hparams.add_hparam(name='e.f', value='') + hparams.parse('e.f=abc') + self.assertEqual('abc', getattr(hparams, 'e.f')) + hparams.add_hparam(name='d..', value=0.0) + hparams.parse('d..=10.0') + self.assertEqual(10.0, getattr(hparams, 'd..')) + def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') hparams.override_from_dict({'a': -2, 'c': 'identity'}) |