/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include #include #include #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { class ConvertTest : public ClientLibraryTestBase { public: explicit ConvertTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); } }; TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1U32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); ConvertElementType(a, U32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1PRED) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 0, -64}); ConvertElementType(a, PRED); std::array expected = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1U32ToR1U32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); ConvertElementType(a, U32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1U32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1U32ToR1PRED) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 0, 64}); ConvertElementType(a, PRED); std::array expected = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42.0f, 64.0f}); ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1PRED) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42.0f, 0.0f, 64.0f}); ConvertElementType(a, PRED); std::array expected = {true, false, true}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); ConvertElementType(a, F32); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {true, false, true}); ConvertElementType(a, S32); std::vector expected = {1, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1U32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {true, false, true}); ConvertElementType(a, U32); std::vector expected = {1, 0, 1}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {true, false, true}); ConvertElementType(a, F32); std::vector expected = {1., 0., 1.}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); ConvertElementType(a, F32); std::vector expected = {}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42.6, 64.4}); ConvertElementType(a, S32); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { XlaBuilder builder(TestName()); std::vector arg{ -9223371216516022272, -2, -1, -0x7FFFFFFF, -0x80000000, 0, 1, 2, 1073742145, 1073742656, 0x7FFFFFFF, 0x80000000, 826720496944058148, 4296062029846194332, 0x0007FB72E4000000LL, 0x0007FB72E4000001LL, 0x0007FB72E6000000LL, 0x0007FB72E7000000LL, 0x0007FB72E7FFFFFFLL, 0x0007FB72E8000000LL, 0x0007FB72E8000001LL, 0x0007FB72EA000000LL, 0x0007FB72EB000000LL, 0x0007FB72EBFFFFFFLL, 0x0007FB72EC000000LL, 0x7FFFFF0000000000LL, 0x7FFFFF8000000000LL, 0x7FFFFFFFFFFFFF00, static_cast(0xFFFFFFFFFFFFFFFF), static_cast(0x0000f234e67e0001LL), static_cast(0x8000000000000000), static_cast(0x8000000000000000LL), static_cast(0x8000000000000001LL), static_cast(0x8000008000000000LL), static_cast(0x8000010000000000LL), }; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, F32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, U32); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { XlaBuilder builder(TestName()); // Test cases from compiler_rt library. std::vector arg{0.0f, 0.5f, 0.99f, 1.0f, 1.5f, 1.99f, 2.0f, 2.01f, 2147483648.f, -0.5f, -0.99f, -1.0f, -1.5f, -1.99f, -2.0f, -2.01f, 9223371487098961920.f, 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; Literal arg_literal = LiteralUtil::CreateR1({arg}); auto arg_param = Parameter(&builder, 0, arg_literal.shape(), "arg_param"); std::unique_ptr arg_data = client_->TransferToServer(arg_literal).ConsumeValueOrDie(); ConvertElementType(arg_param, S64); std::vector expected(arg.size()); for (int64 i = 0; i < arg.size(); ++i) { expected[i] = static_cast(arg[i]); } ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {32, 64}); ConvertElementType(a, F32); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {32, 64}); ConvertElementType(a, S32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {32, 64}); ConvertElementType(a, U32); std::vector expected = {32, 64}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {32.0f, 64.0f}); ConvertElementType(a, F64); std::vector expected = {32.0, 64.0}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {32.0, 64.0}); ConvertElementType(a, F32); std::vector expected = {32.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertS32Extremes) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {std::numeric_limits::min(), std::numeric_limits::max()}); ConvertElementType(a, F32); std::vector expected = { static_cast(std::numeric_limits::min()), static_cast(std::numeric_limits::max())}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } TEST_F(ConvertTest, ConvertMapToS32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in"); ConvertElementType(param, S32); auto a = ConstantR1(&builder, {42.0f, 64.0f}); Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); } TEST_F(ConvertTest, ConvertMapToF32) { XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in"); ConvertElementType(param, F32); auto a = ConstantR1(&builder, {42, 64}); Map(&builder, {a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); } // Regression test for b/31758660. When ReshapeMover transforms // input -> reshape -> convert // to // input -> convert -> reshape // the new convert should have the same element type as the old convert. TEST_F(ConvertTest, ConvertReshape) { XlaBuilder builder(TestName()); auto input = ConstantR1(&builder, {42}); auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); ConvertElementType(reshape, F32); ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } std::vector GetInterestingF16ConversionTestCases() { float infinity = std::numeric_limits::infinity(); float half_min_positive_normal = tensorflow::bit_cast(0x38800000); float half_max_subnormal = tensorflow::bit_cast(0x387fc000); float half_min_positive_subnormal = tensorflow::bit_cast(0x33800000); float half_max = 65504.0f; std::vector test_cases( {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f, -half_min_positive_subnormal, -half_max_subnormal, -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal, half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max, (half_max * 2 + 1), infinity}); return test_cases; } XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::vector test_cases = GetInterestingF16ConversionTestCases(); std::vector input; absl::c_transform(test_cases, std::back_inserter(input), [](float f) { return Eigen::half(f); }); std::vector expected_output; absl::c_transform(input, std::back_inserter(expected_output), [](Eigen::half h) { return static_cast(h); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( Parameter(&builder, 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), "param"), F32); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::vector input = GetInterestingF16ConversionTestCases(); std::vector expected_output; absl::c_transform(input, std::back_inserter(expected_output), [](float f) { return Eigen::half(f); }); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr dot_lhs_handle, client_->TransferToServer(LiteralUtil::CreateR1(input))); XlaBuilder builder(TestName()); ConvertElementType( Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), "param"), F16); ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); } XLA_TEST_F(ConvertTest, ConvertC64ToC64) { XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; ConvertElementType(ConstantR1(&builder, x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; ConvertElementType(ConstantR1(&builder, x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; ConvertElementType(ConstantR1(&builder, x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; ConvertElementType(ConstantR1(&builder, unsigned_x), S64); std::vector signed_x = {{42, -1}}; ComputeAndCompareR1(&builder, signed_x, {}); } XLA_TEST_F(ConvertTest, ConvertS64U64) { XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; ConvertElementType(ConstantR1(&builder, signed_x), U64); std::vector unsigned_x = { {42, UINT64_MAX, tensorflow::MathUtil::IPow(2, 63)}}; ComputeAndCompareR1(&builder, unsigned_x, {}); } XLA_TEST_F(ConvertTest, ConvertBF16F32) { XlaBuilder builder(TestName()); std::vector all_bfloats(1 << 16); for (int i = 0; i < all_bfloats.size(); ++i) { all_bfloats[i].value = i; } std::vector expected(all_bfloats.size()); for (int i = 0; i < expected.size(); ++i) { expected[i] = (1U << 16) * i; } // Exhaustively test all bf16 to f32 conversions. xla::XlaOp all_bfloats_bf16 = ConstantR1(&builder, all_bfloats); xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32); BitcastConvertType(all_bfloats_f32, U32); ComputeAndCompareR1(&builder, expected, {}); } } // namespace } // namespace xla