aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-05-15 10:58:43 -0700
committerGravatar Yifei Feng <yifeif@google.com>2018-05-15 10:58:43 -0700
commit41ed2381d9d958abf6c93eb8e49e88282a5099ae (patch)
tree03142a327497f65199f8c9ae86062de1ce1c1e71 /tensorflow/contrib/training
parentde4a6e646be56ca59c78dd6f92f8f6bcc7196696 (diff)
parenta7a3bb3df12c632b81bf1b23f8405f92a0c903c3 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/hparam.py9
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py15
2 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index f0418f04ba..3beb7bfe30 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -34,7 +34,7 @@ from tensorflow.python.util import deprecation
# where <rhs> is either a single token or [] enclosed list of tokens.
# For example: "var[1] = a" or "x = [1,2,3]"
PARAM_RE = re.compile(r"""
- (?P<name>[a-zA-Z][\w]*) # variable name: "var" or "x"
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
\s*=\s*
((?P<val>[^,\[]*) # single value: "a" or None
@@ -200,6 +200,13 @@ def parse_values(values, type_map):
If a hyperparameter name in both an index assignment and scalar assignment,
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
+ The hyperparameter name may contain '.' symbols, which will result in an
+ attribute name that is only accessible through the getattr and setattr
+ functions. (And must be first explicit added through add_hparam.)
+
+ WARNING: Use of '.' in your variable names is allowed, but is not well
+ supported and not recommended.
+
The `value` in `name=value` must follows the syntax according to the
type of the parameter:
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 11fd15b527..660c97f25e 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -118,6 +118,21 @@ class HParamsTest(test.TestCase):
self.assertEqual('2.3"', hparams2.c_c)
self.assertEqual('/a=b/c/d', hparams2.d)
+ def testWithPeriodInVariableName(self):
+ hparams = hparam.HParams()
+ hparams.add_hparam(name='a.b', value=0.0)
+ hparams.parse('a.b=1.0')
+ self.assertEqual(1.0, getattr(hparams, 'a.b'))
+ hparams.add_hparam(name='c.d', value=0.0)
+ with self.assertRaisesRegexp(ValueError, 'Could not parse'):
+ hparams.parse('c.d=abc')
+ hparams.add_hparam(name='e.f', value='')
+ hparams.parse('e.f=abc')
+ self.assertEqual('abc', getattr(hparams, 'e.f'))
+ hparams.add_hparam(name='d..', value=0.0)
+ hparams.parse('d..=10.0')
+ self.assertEqual(10.0, getattr(hparams, 'd..'))
+
def testSetFromMap(self):
hparams = hparam.HParams(a=1, b=2.0, c='tanh')
hparams.override_from_dict({'a': -2, 'c': 'identity'})