aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-11 11:04:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 11:10:59 -0700
commit1d6973d68b5d617e3a2dbf935643d0c0e4dcdac5 (patch)
tree57a4bdf181f42b8cd0bb4f4394bcf00740e4ab91 /tensorflow/contrib/training
parente1562e72c197ec830547a051ddfe0f720acb9f67 (diff)
RELNOTES: This allows the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)").
NOTE: the use of '.' in variable names is now allowed, but it is not recommended. PiperOrigin-RevId: 196278660
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py9
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py15
2 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index f0418f04ba..3beb7bfe30 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(r"""
- (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x"
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
\s*=\s*
((?P<val>[^,\[]*) # single value: "a" or None
@@ -200,6 +200,13 @@ def parse_values(values, type_map):
If a hyperparameter name in both an index assignment and scalar assignment,
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
+ The hyperparameter name may contain '.' symbols, which will result in an
+ attribute name that is only accessible through the getattr and setattr
+ functions. (And must be first explicit added through add_hparam.)
+
+ WARNING: Use of '.' in your variable names is allowed, but is not well
+ supported and not recommended.
+
The `value` in `name=value` must follows the syntax according to the
type of the parameter:
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 11fd15b527..660c97f25e 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -118,6 +118,21 @@ class HParamsTest(test.TestCase):
self.assertEqual('2.3"', hparams2.c_c)
self.assertEqual('/a=b/c/d', hparams2.d)
+ def testWithPeriodInVariableName(self):
+ hparams = hparam.HParams()
+ hparams.add_hparam(name='a.b', value=0.0)
+ hparams.parse('a.b=1.0')
+ self.assertEqual(1.0, getattr(hparams, 'a.b'))
+ hparams.add_hparam(name='c.d', value=0.0)
+ with self.assertRaisesRegexp(ValueError, 'Could not parse'):
+ hparams.parse('c.d=abc')
+ hparams.add_hparam(name='e.f', value='')
+ hparams.parse('e.f=abc')
+ self.assertEqual('abc', getattr(hparams, 'e.f'))
+ hparams.add_hparam(name='d..', value=0.0)
+ hparams.parse('d..=10.0')
+ self.assertEqual(10.0, getattr(hparams, 'd..'))
+
def testSetFromMap(self):
hparams = hparam.HParams(a=1, b=2.0, c='tanh')
hparams.override_from_dict({'a': -2, 'c': 'identity'})