aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/test_utils.cc
diff options
context:
space:
mode:
authorGravatar Bixia Zheng <bixia@google.com>2018-04-03 19:20:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 19:23:35 -0700
commitede6e1ff31531cae98844676af4981a821760188 (patch)
tree22cbf8e38257fa5487a16384532a9852b9aa3513 /tensorflow/compiler/xla/tests/test_utils.cc
parent62d547aa53dd0e7f53e8544b3c41d9274727d333 (diff)
[TF:XLA] Add half precision support to test_utils.
PiperOrigin-RevId: 191535944
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc25
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 0bc7df2a65..821432ef7d 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -23,14 +23,14 @@ namespace xla {
namespace {
-template <typename FloatT>
-void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine) {
+template <typename FloatT, typename GeneratorT>
+void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
+ std::minstd_rand0* engine) {
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
// Create uniform numbers between 1 and 1.125 to avoid creating denormal
// numbers.
- std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f);
+ std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f);
const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000;
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> indices) {
@@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal,
FloatT index_bias =
static_cast<FloatT>(index_product % 113 - negative_bias) /
static_cast<FloatT>(256.0f);
- return (generator(*engine) - 1.0625) + index_bias;
+ return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias;
}));
}
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal,
+ std::minstd_rand0* engine) {
+ PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine);
+}
+
+template <>
+void PopulateWithRandomFloatingPointData<half>(Literal* literal,
+ std::minstd_rand0* engine) {
+ PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine);
+}
+
// The standard library does not have a case for bfloat16, unsurprisingly, so we
// handle that one specially.
template <>
@@ -100,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
case BF16:
PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine);
break;
+ case F16:
+ PopulateWithRandomFloatingPointData<half>(literal.get(), engine);
+ break;
case F32:
PopulateWithRandomFloatingPointData<float>(literal.get(), engine);
break;