diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-26 08:52:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 08:55:17 -0700 |
commit | a5a1e9e43131b387395930f38234fc10b02d874b (patch) | |
tree | 0cd74c0509e8c8097f53535ce428800aff8cb5de /tensorflow/contrib/training | |
parent | c3436d6757a77ab1fefd3f6000a1e961a9ab9881 (diff) |
Updated test (but not source) of https://www.tensorflow.org/api_docs/python/tf/contrib/training/HParams to show that it allows '=' in the values.
PiperOrigin-RevId: 190470578
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/hparam_test.py | 42 |
1 files changed, 32 insertions, 10 deletions
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index 16397622ed..96eff86d8d 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -38,40 +38,60 @@ class HParamsTest(test.TestCase): self.assertFalse('bar' in hparams) def testSomeValues(self): - hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6') - self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values()) - expected_str = '[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\')]' + hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6', d='/a/b=c/d') + self.assertDictEqual( + {'aaa': 1, 'b': 2.0, 'c_c': 'relu6', 'd': '/a/b=c/d'}, + hparams.values()) + expected_str = ('[(\'aaa\', 1), (\'b\', 2.0), (\'c_c\', \'relu6\'), ' + '(\'d\', \'/a/b=c/d\')]') self.assertEqual(expected_str, str(hparams.__str__())) self.assertEqual(expected_str, str(hparams)) self.assertEqual(1, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('aaa=12') self.assertDictEqual({ 'aaa': 12, 'b': 2.0, - 'c_c': 'relu6' + 'c_c': 'relu6', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(2.0, hparams.b) self.assertEqual('relu6', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=relu4, b=-2.0e10') self.assertDictEqual({ 'aaa': 12, 'b': -2.0e10, - 'c_c': 'relu4' + 'c_c': 'relu4', + 'd': '/a/b=c/d' }, hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(-2.0e10, hparams.b) self.assertEqual('relu4', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=,b=0,') - self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': ''}, hparams.values()) + self.assertDictEqual({'aaa': 12, 'b': 0, 'c_c': '', 'd': '/a/b=c/d'}, + hparams.values()) self.assertEqual(12, hparams.aaa) self.assertEqual(0.0, hparams.b) self.assertEqual('', hparams.c_c) + self.assertEqual('/a/b=c/d', hparams.d) hparams.parse('c_c=2.3",b=+2,') self.assertEqual(2.0, hparams.b) self.assertEqual('2.3"', hparams.c_c) + hparams.parse('d=/a/b/c/d,aaa=11,') + self.assertEqual(11, hparams.aaa) + self.assertEqual(2.0, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a/b/c/d', hparams.d) + hparams.parse('b=1.5,d=/a=b/c/d,aaa=10,') + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) + self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'): hparams.parse('x=123') with self.assertRaisesRegexp(ValueError, 'Could not parse'): @@ -84,17 +104,19 @@ class HParamsTest(test.TestCase): hparams.parse('b=relu') with self.assertRaisesRegexp(ValueError, 'Must not pass a list'): hparams.parse('aaa=[123]') - self.assertEqual(12, hparams.aaa) - self.assertEqual(2.0, hparams.b) + self.assertEqual(10, hparams.aaa) + self.assertEqual(1.5, hparams.b) self.assertEqual('2.3"', hparams.c_c) + self.assertEqual('/a=b/c/d', hparams.d) # Exports to proto. hparam_def = hparams.to_proto() # Imports from proto. hparams2 = hparam.HParams(hparam_def=hparam_def) # Verifies that all hparams are restored. - self.assertEqual(12, hparams2.aaa) - self.assertEqual(2.0, hparams2.b) + self.assertEqual(10, hparams2.aaa) + self.assertEqual(1.5, hparams2.b) self.assertEqual('2.3"', hparams2.c_c) + self.assertEqual('/a=b/c/d', hparams2.d) def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') |