From 18343616da47a9c3eab79b5028ac3d8bf786f2ff Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Mon, 30 Apr 2018 16:11:38 -0700
Subject: [XLA] Change the TF2XLA bridge to perform F16 reduction using F32
 data type.

Add test cases to test that reduce sum for bfloat16 and float16 doesn't
lose too much precision.

PiperOrigin-RevId: 194862078
---
 tensorflow/compiler/tests/reduce_ops_test.py | 64 ++++++++++++++++++++++++++++
 tensorflow/compiler/tf2xla/xla_helpers.cc    |  2 +-
 2 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 2c084b04fa..7420724bdb 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import itertools
 import numpy as np
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
@@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase):
     self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
 
 
+class ReduceOpPrecisionTest(XLATestCase):
+
+  def _testReduceSum(self,
+                     expected_result,
+                     dtype,
+                     test_inputs,
+                     rtol=1e-3,
+                     atol=1e-4):
+    """Tests reduce sum on a list of input arrays.
+
+    For each array in test_inputs, check that performing reduce sum on the
+    array produces a value that is close to the expected result.
+
+    Args:
+      expected_result: the expected result.
+      dtype: the data type of the reduce sum operation.
+      test_inputs: a list of input arrays for the reduce sum operation.
+      rtol: the relative error.
+      atol: the absolute error.
+    """
+
+    for test_input in test_inputs:
+      with self.test_session() as sess:
+        with self.test_scope():
+          a = array_ops.placeholder(dtype)
+          index = array_ops.placeholder(dtypes.int32)
+          out = math_ops.reduce_sum(a, index)
+        result = sess.run(out, {
+            a: np.array(test_input, dtype=dtype),
+            index: [0]
+        })
+        # Compare the results using float32 type.
+        self.assertAllClose(
+            np.float32(result),
+            np.float32(expected_result),
+            rtol=rtol,
+            atol=atol)
+
+  def testReduceSumF16(self):
+    """Tests the reduce sum of float16 doesn't lose too much precision."""
+
+    if np.float16 not in self.all_types:
+      return
+
+    f16_max = np.finfo(np.float16).max
+    self._testReduceSum(
+        f16_max, np.float16,
+        itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3))
+
+  def testReduceSumBF16(self):
+    """Tests the reduce sum of bfloat16 doesn't lose too much precision."""
+
+    if dtypes.bfloat16.as_numpy_dtype not in self.all_types:
+      return
+
+    bf16_max = np.float32(dtypes.bfloat16.max)
+    f32_max = dtypes.float32.max
+    value = min(bf16_max, f32_max - bf16_max)
+    self._testReduceSum(
+        dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype,
+        itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3))
+
+
 if __name__ == '__main__':
   googletest.main()
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 62a5114837..a3deb02a1f 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -278,7 +278,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth,
 }
 
 DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
-  if (dtype == DT_BFLOAT16) {
+  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
     return DT_FLOAT;
   }
   return dtype;
-- 
cgit v1.2.3
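
Note (not part of the patch): a minimal NumPy-only sketch of the overflow that returning DT_FLOAT from SumAccumulationType for DT_HALF is meant to avoid. It mirrors the [f16_max, f16_max, -f16_max] input used in testReduceSumF16 above; the script and its variable names are illustrative, not taken from the commit.

import numpy as np

f16_max = np.finfo(np.float16).max  # 65504.0
values = np.array([f16_max, f16_max, -f16_max], dtype=np.float16)

# Accumulating in float16: the first addition overflows to inf, so the
# reduce-sum result is inf instead of 65504.
naive = values.sum(dtype=np.float16)

# Accumulating in float32 and casting back to float16 (roughly what the
# bridge now does for F16 reduce sums) recovers the exact answer.
accurate = np.float16(values.sum(dtype=np.float32))

print(naive)     # inf
print(accurate)  # 65504.0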