author     Bixia Zheng <bixia@google.com>                    2018-04-30 16:11:38 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-30 16:14:39 -0700
commit     18343616da47a9c3eab79b5028ac3d8bf786f2ff (patch)
tree       69fad84783fef47790c49416b2e8b5f2897d52e0
parent     30fcdecc05e6b25ab8d451997904e40b2a76acd4 (diff)
[XLA] Change the TF2XLA bridge to perform F16 reduction using F32 data type.
Add test cases to verify that reduce sum for bfloat16 and float16 doesn't lose too much precision.

PiperOrigin-RevId: 194862078
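Background, not part of the commit: float16 carries an 11-bit significand (10 stored bits plus the implicit bit), so a running float16 sum stops absorbing new terms once the partial sum grows large. A minimal NumPy sketch of the failure mode this change addresses:

import numpy as np

# The exact sum of 4096 ones is 4096, which float16 can represent, but a
# float16 accumulator stalls at 2048: there the spacing between adjacent
# float16 values is 2.0, so adding 1.0 no longer changes the total.
acc16 = np.float16(0.0)
for _ in range(4096):
    acc16 = acc16 + np.float16(1.0)
print(acc16)  # 2048.0

# Accumulating in float32 and casting back, as the bridge now does:
acc32 = np.float32(0.0)
for _ in range(4096):
    acc32 += np.float32(1.0)
print(np.float16(acc32))  # 4096.0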
-rw-r--r--  tensorflow/compiler/tests/reduce_ops_test.py  64
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc       2
2 files changed, 65 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 2c084b04fa..7420724bdb 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import itertools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+class ReduceOpPrecisionTest(XLATestCase):
+
+ def _testReduceSum(self,
+ expected_result,
+ dtype,
+ test_inputs,
+ rtol=1e-3,
+ atol=1e-4):
+ """Tests reduce sum on a list of input arrays.
+
+ For each array in test_inputs, check that performing reduce sum on the array
+ produces a value that is close to the expected result.
+
+ Args:
+ expected_result: the expected result.
+ dtype: the data type of the reduce sum operation.
+ test_inputs: a list of input arrays for the reduce sum operation.
+      rtol: the relative tolerance of the comparison.
+      atol: the absolute tolerance of the comparison.
+ """
+
+ for test_input in test_inputs:
+ with self.test_session() as sess:
+ with self.test_scope():
+ a = array_ops.placeholder(dtype)
+ index = array_ops.placeholder(dtypes.int32)
+ out = math_ops.reduce_sum(a, index)
+ result = sess.run(out, {
+ a: np.array(test_input, dtype=dtype),
+ index: [0]
+ })
+ # Compare the results using float32 type.
+ self.assertAllClose(
+ np.float32(result),
+ np.float32(expected_result),
+ rtol=rtol,
+ atol=atol)
+
+ def testReduceSumF16(self):
+ """Tests the reduce sum of float16 doesn't lose too much precision."""
+
+ if np.float16 not in self.all_types:
+ return
+
+ f16_max = np.finfo(np.float16).max
+ self._testReduceSum(
+ f16_max, np.float16,
+ itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))
+
+ def testReduceSumBF16(self):
+ """Tests the reduce sum of bfloat16 doesn't lose too much precision."""
+
+ if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
+ return
+
+ bf16_max = np.float32(dtypes.bfloat16.max)
+ f32_max = dtypes.float32.max
+ value = min(bf16_max, f32_max - bf16_max)
+ self._testReduceSum(
+ dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
+ itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
+
+
if __name__ == '__main__':
googletest.main()
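An aside on the choice of test inputs, not from the commit: every permutation of [f16_max, f16_max, -f16_max] sums exactly to f16_max, but any ordering that adds the two positive values first produces an intermediate 2 * f16_max = 131008, which overflows float16 yet fits easily in float32. Likewise, the bfloat16 test picks value = min(bf16_max, f32_max - bf16_max) so that even the widened float32 intermediate bf16_max + value stays finite. A quick NumPy check of the float16 case:

import numpy as np

f16_max = np.finfo(np.float16).max  # 65504.0

# float16 accumulator: the positive-first orderings overflow to inf,
# and subtracting f16_max cannot bring inf back down.
acc = np.float16(f16_max) + np.float16(f16_max)  # inf
print(acc - np.float16(f16_max))                 # inf

# float32 accumulator: 131008.0 is exact, so every ordering recovers f16_max.
acc = np.float32(f16_max) + np.float32(f16_max)  # 131008.0
print(np.float16(acc - np.float32(f16_max)))     # 65504.0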
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 62a5114837..a3deb02a1f 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -278,7 +278,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
}
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
- if (dtype == DT_BFLOAT16) {
+ if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return DT_FLOAT;
}
return dtype;
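With this change the helper routes both 16-bit float types to a float32 accumulator. A Python paraphrase of the updated C++ for illustration only (the DT_* strings stand in for TensorFlow's DataType enum values; this is not bridge code):

# bfloat16 stores 7 explicit mantissa bits and float16 stores 10, so a
# same-width running sum drops low-order bits quickly; sums over either
# type should therefore accumulate in float32.
def sum_accumulation_type(dtype):
    if dtype in ("DT_BFLOAT16", "DT_HALF"):
        return "DT_FLOAT"
    return dtype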