aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/adam_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 13:09:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 13:16:00 -0700
commit2d4214415269bee2c8c98d5466c540e4004652fd (patch)
tree08e9954de2577c86d890f8fa22b956bc21cf3549 /tensorflow/compiler/tests/adam_test.py
parenta473f435cf7345cb9dc2efacb471a7f318141a9b (diff)
This change makes casts to bfloat16 use rounding instead of truncation by default. The motivation is that rounding achieves better accuracy than truncation.
PiperOrigin-RevId: 209985826
Diffstat (limited to 'tensorflow/compiler/tests/adam_test.py')
-rw-r--r--tensorflow/compiler/tests/adam_test.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 0d2e4d0296..df0f21471a 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -53,7 +54,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@@ -95,7 +96,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testTensorLearningRate(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)
@@ -137,7 +138,7 @@ class AdamOptimizerTest(xla_test.XLATestCase):
def testSharing(self):
for dtype in self.float_types:
# TODO: test fails for float16 due to excessive precision requirements.
- if dtype == np.float16:
+ if dtype in [np.float16, dtypes.bfloat16.as_numpy_dtype]:
continue
with self.test_session(), self.test_scope():
variable_scope.get_variable_scope().set_use_resource(True)