diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-23 13:09:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-23 13:16:00 -0700 |
commit | 2d4214415269bee2c8c98d5466c540e4004652fd (patch) | |
tree | 08e9954de2577c86d890f8fa22b956bc21cf3549 /tensorflow/compiler/tests/adam_test.py | |
parent | a473f435cf7345cb9dc2efacb471a7f318141a9b (diff) |
This change makes casts to bfloat16 use rounding instead of truncation by default. The motivation is that rounding achieves better accuracy than truncation.
PiperOrigin-RevId: 209985826
Diffstat (limited to 'tensorflow/compiler/tests/adam_test.py')
-rw-r--r-- | tensorflow/compiler/tests/adam_test.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py index 0d2e4d0296..df0f21471a 100644 --- a/tensorflow/compiler/tests/adam_test.py +++ b/tensorflow/compiler/tests/adam_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -53,7 +54,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testBasic(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -95,7 +96,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testTensorLearningRate(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) @@ -137,7 +138,7 @@ class AdamOptimizerTest(xla_test.XLATestCase): def testSharing(self): for dtype in self.float_types: # TODO: test fails for float16 due to excessive precision requirements. - if dtype == np.float16: + if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]: continue with self.test_session(), self.test_scope(): variable_scope.get_variable_scope().set_use_resource(True) |