diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-06-18 09:57:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 09:59:59 -0700 |
commit | e80732c9895d1283af9b98d6277ad1a1015e2e9a (patch) | |
tree | 14895657394f9cdfed8435460e37fe89a45ba599 /tensorflow/contrib/opt | |
parent | 8ecf506fb8464dd273ce59f512f5e20d37dd5cfd (diff) |
Merge changes from github.
PiperOrigin-RevId: 201011811
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/adamax_test.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/opt/python/training/model_average_optimizer.py | 2 |
2 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index 21bf3f5313..915e6504e1 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -224,8 +224,10 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index a7c97a1da2..b6b10e500b 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -62,7 +62,7 @@ class ModelAverageCustomGetter(object): """ def __init__(self, worker_device): - """Create a new `ElasticAverageCustomGetter`. + """Create a new `ModelAverageCustomGetter`. Args: worker_device: String. Name of the `worker` job. |