/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { class UnaryOpTest : public ClientLibraryTestBase { protected: template T inf() { return std::numeric_limits::infinity(); } template void AbsSize0TestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1(&builder, {}); Abs(arg); if (primitive_util::NativeToPrimitiveType() == C64) { ComputeAndCompareR1(&builder, {}, {}); } else { ComputeAndCompareR1(&builder, {}, {}); } } template void AbsTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1(&builder, {-2, 25, 0, -123, inf(), -inf()}); Abs(arg); ComputeAndCompareR1(&builder, {2, 25, 0, 123, inf(), inf()}, {}); } template void SignTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1( &builder, {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); Sign(arg); ComputeAndCompareR1(&builder, {-1, 1, 0, 0, -1, 1, -1}, {}); } template void SignAbsTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1(&builder, {-2, 25, 0, -123}); auto sign = Sign(arg); auto abs = Abs(arg); Sub(Mul(sign, abs), arg); ComputeAndCompareR1(&builder, {0, 0, 0, 0}, {}); } }; template <> int UnaryOpTest::inf() { return 2147483647; } template <> int64 UnaryOpTest::inf() { return 0x7FFFFFFFFFFFFFFFl; } template <> void UnaryOpTest::AbsTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.3f, 0.4f}, {0, inf()}, {-inf(), 0}}); Abs(arg); Literal expected = LiteralUtil::CreateR1({2, 25, 0, 0.5, inf(), inf()}); ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> void UnaryOpTest::SignTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1( &builder, {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); Sign(arg); Literal expected = LiteralUtil::CreateR1( {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } template <> void UnaryOpTest::SignAbsTestHelper() { XlaBuilder builder(TestName()); auto arg = ConstantR1(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); auto sign = Sign(arg); auto abs = Abs(arg); Sub(Mul(sign, ConvertElementType(abs, C64)), arg); Literal expected = LiteralUtil::CreateR1({0, 0, 0, 0}); ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { AbsSize0TestHelper(); AbsSize0TestHelper(); AbsSize0TestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR1) { AbsTestHelper(); AbsTestHelper(); AbsTestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR0) { XlaBuilder builder(TestName()); auto argi = ConstantR0(&builder, -5); auto absi = Abs(argi); auto argf = ConstantR0(&builder, -3.0f); auto absf = Abs(argf); auto argf0 = ConstantR0(&builder, -0.0f); auto absf0 = Abs(argf0); auto argc = ConstantR0(&builder, {-0.3f, 0.4f}); auto absc = Abs(argc); Add(Add(absc, absf0), Add(absf, ConvertElementType(absi, F32))); ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { XlaBuilder builder(TestName()); auto argi = ConstantR0(&builder, -5); auto sgni = Sign(argi); // -1 auto argf = ConstantR0(&builder, -4.0f); auto sgnf = Sign(argf); // -1 auto argf0 = ConstantR0(&builder, -0.0f); auto sgnf0 = Sign(argf0); // 0 auto argc = ConstantR0(&builder, {-.3, .4}); auto sgnc = Sign(argc); // (-.6, .8) Add(sgnc, ConvertElementType( Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64)); Literal expected = LiteralUtil::CreateR0({-2.6f, 0.8f}); ComputeAndCompareLiteral(&builder, expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); SignTestHelper(); SignTestHelper(); SignTestHelper(); } XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); SignAbsTestHelper(); SignAbsTestHelper(); } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { XlaBuilder builder(TestName()); auto arg = ConstantR2(&builder, {{1.0, -2.0}, {-3.0, 4.0}}); auto sign = Sign(arg); auto abs = Abs(arg); Sub(Mul(sign, abs), arg); ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 1}); auto rhs = ConstantR1(&builder, {1, 1}); ConvertElementType(Eq(lhs, rhs), S32); ComputeAndCompareR1(&builder, {0, 1}, {}); } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 1}); auto rhs = ConstantR1(&builder, {1, 1}); ConvertElementType(Eq(lhs, rhs), F32); ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); } } // namespace } // namespace xla