aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-26 08:52:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 08:55:17 -0700
commita5a1e9e43131b387395930f38234fc10b02d874b (patch)
tree0cd74c0509e8c8097f53535ce428800aff8cb5de /tensorflow/contrib/training
parentc3436d6757a77ab1fefd3f6000a1e961a9ab9881 (diff)
Updated test (but not source) of https://www.tensorflow.org/api_docs/python/tf/contrib/training/HParams to show that it allows '=' in the values.
PiperOrigin-RevId: 190470578
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/hparam_test.py42
1 files changed, 32 insertions, 10 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index 16397622ed..96eff86d8d 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -38,40 +38,60 @@ class HParamsTest(test.TestCase):
self.assertFalse('bar' in hparams)
def testSomeValues(self):
- hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6')
- self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values())
- expected_str = '[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\')]'
+ hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d')
+ self.assertDictEqual(
+ {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'},
+ hparams.values())
+ expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), '
+ '(\'d\', \'/a/b=c/d\')]')
self.assertEqual(expected_str, str(hparams.__str__()))
self.assertEqual(expected_str, str(hparams))
self.assertEqual(1, hparams.aaa)
self.assertEqual(2.0, hparams.b)
self.assertEqual('relu6', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('aaa=12')
self.assertDictEqual({
'aaa': 12,
'b': 2.0,
- 'c_c': 'relu6'
+ 'c_c': 'relu6',
+ 'd': '/a/b=c/d'
}, hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(2.0, hparams.b)
self.assertEqual('relu6', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=relu4, b=-2.0e10')
self.assertDictEqual({
'aaa': 12,
'b': -2.0e10,
- 'c_c': 'relu4'
+ 'c_c': 'relu4',
+ 'd': '/a/b=c/d'
}, hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(-2.0e10, hparams.b)
self.assertEqual('relu4', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=,b=0,')
- self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': ''}, hparams.values())
+ self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'},
+ hparams.values())
self.assertEqual(12, hparams.aaa)
self.assertEqual(0.0, hparams.b)
self.assertEqual('', hparams.c_c)
+ self.assertEqual('/a/b=c/d', hparams.d)
hparams.parse('c_c=2.3",b=+2,')
self.assertEqual(2.0, hparams.b)
self.assertEqual('2.3"', hparams.c_c)
+ hparams.parse('d=/a/b/c/d,aaa=11,')
+ self.assertEqual(11, hparams.aaa)
+ self.assertEqual(2.0, hparams.b)
+ self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a/b/c/d', hparams.d)
+ hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,')
+ self.assertEqual(10, hparams.aaa)
+ self.assertEqual(1.5, hparams.b)
+ self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a=b/c/d', hparams.d)
with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'):
hparams.parse('x=123')
with self.assertRaisesRegexp(ValueError, 'Could not parse'):
@@ -84,17 +104,19 @@ class HParamsTest(test.TestCase):
hparams.parse('b=relu')
with self.assertRaisesRegexp(ValueError, 'Must not pass a list'):
hparams.parse('aaa=[123]')
- self.assertEqual(12, hparams.aaa)
- self.assertEqual(2.0, hparams.b)
+ self.assertEqual(10, hparams.aaa)
+ self.assertEqual(1.5, hparams.b)
self.assertEqual('2.3"', hparams.c_c)
+ self.assertEqual('/a=b/c/d', hparams.d)
# Exports to proto.
hparam_def = hparams.to_proto()
# Imports from proto.
hparams2 = hparam.HParams(hparam_def=hparam_def)
# Verifies that all hparams are restored.
- self.assertEqual(12, hparams2.aaa)
- self.assertEqual(2.0, hparams2.b)
+ self.assertEqual(10, hparams2.aaa)
+ self.assertEqual(1.5, hparams2.b)
self.assertEqual('2.3"', hparams2.c_c)
+ self.assertEqual('/a=b/c/d', hparams2.d)
def testSetFromMap(self):
hparams = hparam.HParams(a=1, b=2.0, c='tanh')