diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/scalar_computations_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/scalar_computations_test.cc | 151 |
1 files changed, 71 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index d0ebb108ae..5a3bcaf086 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" -#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -44,25 +45,26 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. template <typename NativeT> - void TestCompare( - NativeT lhs, NativeT rhs, bool expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice<int64>)) { + void TestCompare(NativeT lhs, NativeT rhs, bool expected, + std::function<XlaOp(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice<int64>)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs); XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0<bool>(&builder, expected, {}); } template <typename NativeT> void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, - tensorflow::gtl::ArraySlice<int64>)) { + std::function<XlaOp(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice<int64>)> + op) { XlaBuilder builder(TestName()); XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs); XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs); - (builder.*op)(lhs_op, rhs_op, {}); + op(lhs_op, rhs_op, {}); ComputeAndCompareR0<NativeT>(&builder, expected, {}); } }; @@ -161,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { ConvertElementType(a, F32); int64 value = 3LL << 35; - std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value); + std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value); std::unique_ptr<GlobalData> a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); ComputeAndCompareR0<float>(&builder, static_cast<float>(value), @@ -225,9 +227,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { XlaBuilder builder(TestName()); - std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f); - std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f); - std::unique_ptr<Literal> c_literal = Literal::CreateR0<float>(0.5f); + std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f); + std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f); + std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f); std::unique_ptr<GlobalData> a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -374,8 +376,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0<uint32>(dividend); - auto divisor_literal = Literal::CreateR0<uint32>(divisor); + auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend); + auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -386,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor); + auto expected_literal = + LiteralUtil::CreateR0<uint32>(dividend / divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -415,8 +418,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = Literal::CreateR0<uint32>(dividend); - auto divisor_literal = Literal::CreateR0<uint32>(divisor); + auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend); + auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor); TF_ASSERT_OK_AND_ASSIGN(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSERT_OK_AND_ASSIGN(auto divisor_data, @@ -427,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor); + auto expected_literal = + LiteralUtil::CreateR0<uint32>(dividend % divisor); EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } @@ -439,7 +443,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x"); Rem(x, ConstantR0<int32>(&builder, 80000)); - std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(87919); + std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()}); } @@ -583,117 +587,116 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare<int32>(2, 1, false, &XlaBuilder::Eq); + TestCompare<int32>(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare<int32>(3, 3, true, &XlaBuilder::Eq); + TestCompare<int32>(3, 3, true, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare<int32>(2, 1, true, &XlaBuilder::Ne); + TestCompare<int32>(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare<int32>(2, 1, true, &XlaBuilder::Ge); + TestCompare<int32>(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare<int32>(1, 5, false, &XlaBuilder::Gt); + TestCompare<int32>(1, 5, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare<int32>(2, 1, false, &XlaBuilder::Le); + TestCompare<int32>(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare<int32>(9, 7, false, &XlaBuilder::Lt); + TestCompare<int32>(9, 7, false, &Lt); TestCompare<int32>(std::numeric_limits<int32>::min(), - std::numeric_limits<int32>::max(), true, &XlaBuilder::Lt); + std::numeric_limits<int32>::max(), true, &Lt); } // U32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare<uint32>(2, 1, false, &XlaBuilder::Eq); + TestCompare<uint32>(2, 1, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare<uint32>(2, 1, true, &XlaBuilder::Ne); + TestCompare<uint32>(2, 1, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare<uint32>(2, 1, true, &XlaBuilder::Ge); + TestCompare<uint32>(2, 1, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare<uint32>(3, 3, true, &XlaBuilder::Ge); + TestCompare<uint32>(3, 3, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare<uint32>(1, 5, false, &XlaBuilder::Gt); - TestCompare<uint32>(5, 5, false, &XlaBuilder::Gt); - TestCompare<uint32>(5, 1, true, &XlaBuilder::Gt); + TestCompare<uint32>(1, 5, false, &Gt); + TestCompare<uint32>(5, 5, false, &Gt); + TestCompare<uint32>(5, 1, true, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare<uint32>(2, 1, false, &XlaBuilder::Le); + TestCompare<uint32>(2, 1, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare<uint32>(9, 7, false, &XlaBuilder::Lt); - TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, - &XlaBuilder::Lt); + TestCompare<uint32>(9, 7, false, &Lt); + TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare<float>(2.0, 1.3, false, &XlaBuilder::Eq); + TestCompare<float>(2.0, 1.3, false, &Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare<float>(2.0, 1.3, true, &XlaBuilder::Ne); + TestCompare<float>(2.0, 1.3, true, &Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare<float>(2.0, 1.9, true, &XlaBuilder::Ge); + TestCompare<float>(2.0, 1.9, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare<float>(3.5, 3.5, true, &XlaBuilder::Ge); + TestCompare<float>(3.5, 3.5, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare<float>(1.0, 5.2, false, &XlaBuilder::Gt); + TestCompare<float>(1.0, 5.2, false, &Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare<float>(2.0, 1.2, false, &XlaBuilder::Le); + TestCompare<float>(2.0, 1.2, false, &Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare<float>(9.0, 7.2, false, &XlaBuilder::Lt); + TestCompare<float>(9.0, 7.2, false, &Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare<float>(-INFINITY, -0.0, true, &XlaBuilder::Lt); + TestCompare<float>(-INFINITY, -0.0, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare<float>(-0.0, 0.0, false, &XlaBuilder::Lt); + TestCompare<float>(-0.0, 0.0, false, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare<float>(0.0, INFINITY, true, &XlaBuilder::Lt); + TestCompare<float>(0.0, INFINITY, true, &Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare<float>(-INFINITY, -0.0, false, &XlaBuilder::Ge); + TestCompare<float>(-INFINITY, -0.0, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare<float>(-0.0, 0.0, true, &XlaBuilder::Ge); + TestCompare<float>(-0.0, 0.0, true, &Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare<float>(0.0, INFINITY, false, &XlaBuilder::Ge); + TestCompare<float>(0.0, INFINITY, false, &Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { @@ -813,65 +816,65 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax<int32>(10, 3, 3, &XlaBuilder::Min); + TestMinMax<int32>(10, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax<int32>(-100, 3, -100, &XlaBuilder::Min); + TestMinMax<int32>(-100, 3, -100, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax<int32>(10, 3, 10, &XlaBuilder::Max); + TestMinMax<int32>(10, 3, 10, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax<int32>(-100, 3, 3, &XlaBuilder::Max); + TestMinMax<int32>(-100, 3, 3, &Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits<int32>::max(); - TestMinMax<uint32>(large, 3, 3, &XlaBuilder::Min); + TestMinMax<uint32>(large, 3, 3, &Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax<uint32>(0, 5, 0, &XlaBuilder::Min); + TestMinMax<uint32>(0, 5, 0, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits<int32>::max(); - TestMinMax<uint32>(large, 3, large, &XlaBuilder::Max); + TestMinMax<uint32>(large, 3, large, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax<uint32>(0, 5, 5, &XlaBuilder::Max); + TestMinMax<uint32>(0, 5, 5, &Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax<float>(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); + TestMinMax<float>(10.1f, 3.1f, 3.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax<float>(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); + TestMinMax<float>(-100.1f, 3.1f, -100.1f, &Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Min); - TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Min); + TestMinMax<float>(NAN, 3.1f, NAN, &Min); + TestMinMax<float>(-3.1f, NAN, NAN, &Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax<float>(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); + TestMinMax<float>(10.1f, 3.1f, 10.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax<float>(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max); + TestMinMax<float>(-100.1f, 3.1f, 3.1f, &Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Max); - TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Max); + TestMinMax<float>(NAN, 3.1f, NAN, &Max); + TestMinMax<float>(-3.1f, NAN, NAN, &Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { @@ -897,18 +900,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { ComputeAndCompareR0<int32>(&b, 10, {}); } -XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - XlaBuilder builder(TestName()); - Literal zero_literal = Literal::Zero(PrimitiveType::F32); - - std::unique_ptr<GlobalData> zero_data = - client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - - XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero"); - SqrtF32(zero); - - ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_); -} XLA_TEST_F(ScalarComputationsTest, RoundScalar) { XlaBuilder builder(TestName()); |