diff options
author | Bixia Zheng <bixia@google.com> | 2018-04-03 19:20:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-03 19:23:35 -0700 |
commit | ede6e1ff31531cae98844676af4981a821760188 (patch) | |
tree | 22cbf8e38257fa5487a16384532a9852b9aa3513 /tensorflow/compiler/xla/tests/test_utils.cc | |
parent | 62d547aa53dd0e7f53e8544b3c41d9274727d333 (diff) |
[TF:XLA] Add half precision support to test_utils.
PiperOrigin-RevId: 191535944
Diffstat (limited to 'tensorflow/compiler/xla/tests/test_utils.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/test_utils.cc | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0bc7df2a65..821432ef7d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -23,14 +23,14 @@ namespace xla { namespace { -template <typename FloatT> -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { +template <typename FloatT, typename GeneratorT> +void PopulateWithRandomFloatingPointDataImpl(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType<FloatT>()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal // numbers. - std::uniform_real_distribution<FloatT> generator(1.0f, 1.125f); + std::uniform_real_distribution<GeneratorT> generator(1.0f, 1.125f); const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate<FloatT>( [&](tensorflow::gtl::ArraySlice<int64> indices) { @@ -52,10 +52,22 @@ void PopulateWithRandomFloatingPointData(Literal* literal, FloatT index_bias = static_cast<FloatT>(index_product % 113 - negative_bias) / static_cast<FloatT>(256.0f); - return (generator(*engine) - 1.0625) + index_bias; + return static_cast<FloatT>(generator(*engine) - 1.0625f) + index_bias; })); } +template <typename FloatT> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine); +} + +template <> +void PopulateWithRandomFloatingPointData<half>(Literal* literal, + std::minstd_rand0* engine) { + PopulateWithRandomFloatingPointDataImpl<half, float>(literal, engine); +} + // The standard library does not have a case for bfloat16, unsurprisingly, so we // handle that one specially. template <> @@ -100,6 +112,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal( case BF16: PopulateWithRandomFloatingPointData<bfloat16>(literal.get(), engine); break; + case F16: + PopulateWithRandomFloatingPointData<half>(literal.get(), engine); + break; case F32: PopulateWithRandomFloatingPointData<float>(literal.get(), engine); break; |