aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/scalar_computations_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/scalar_computations_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc151
1 files changed, 71 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index d0ebb108ae..5a3bcaf086 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -44,25 +45,26 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
protected:
// A template for building and running a binary comparison test.
template <typename NativeT>
- void TestCompare(
- NativeT lhs, NativeT rhs, bool expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
- (builder.*op)(lhs_op, rhs_op, {});
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
template <typename NativeT>
void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
- (builder.*op)(lhs_op, rhs_op, {});
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<NativeT>(&builder, expected, {});
}
};
@@ -161,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
@@ -225,9 +227,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = Literal::CreateR0<float>(0.5f);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
@@ -374,8 +376,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -386,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend / divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -415,8 +418,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -427,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend % divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -439,7 +443,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(87919);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
@@ -583,117 +587,116 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
// S32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<int32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
- TestCompare<int32>(3, 3, true, &XlaBuilder::Eq);
+ TestCompare<int32>(3, 3, true, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<int32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<int32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
- TestCompare<int32>(1, 5, false, &XlaBuilder::Gt);
+ TestCompare<int32>(1, 5, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<int32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
- TestCompare<int32>(9, 7, false, &XlaBuilder::Lt);
+ TestCompare<int32>(9, 7, false, &Lt);
TestCompare<int32>(std::numeric_limits<int32>::min(),
- std::numeric_limits<int32>::max(), true, &XlaBuilder::Lt);
+ std::numeric_limits<int32>::max(), true, &Lt);
}
// U32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<uint32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<uint32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
- TestCompare<uint32>(3, 3, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(3, 3, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
- TestCompare<uint32>(1, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 1, true, &XlaBuilder::Gt);
+ TestCompare<uint32>(1, 5, false, &Gt);
+ TestCompare<uint32>(5, 5, false, &Gt);
+ TestCompare<uint32>(5, 1, true, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<uint32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
- TestCompare<uint32>(9, 7, false, &XlaBuilder::Lt);
- TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
- &XlaBuilder::Lt);
+ TestCompare<uint32>(9, 7, false, &Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt);
}
// F32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
- TestCompare<float>(2.0, 1.3, false, &XlaBuilder::Eq);
+ TestCompare<float>(2.0, 1.3, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
- TestCompare<float>(2.0, 1.3, true, &XlaBuilder::Ne);
+ TestCompare<float>(2.0, 1.3, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
- TestCompare<float>(2.0, 1.9, true, &XlaBuilder::Ge);
+ TestCompare<float>(2.0, 1.9, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
- TestCompare<float>(3.5, 3.5, true, &XlaBuilder::Ge);
+ TestCompare<float>(3.5, 3.5, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
- TestCompare<float>(1.0, 5.2, false, &XlaBuilder::Gt);
+ TestCompare<float>(1.0, 5.2, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
- TestCompare<float>(2.0, 1.2, false, &XlaBuilder::Le);
+ TestCompare<float>(2.0, 1.2, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
- TestCompare<float>(9.0, 7.2, false, &XlaBuilder::Lt);
+ TestCompare<float>(9.0, 7.2, false, &Lt);
}
// F32 comparisons with exceptional values. The test names encode the
// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, true, &XlaBuilder::Lt);
+ TestCompare<float>(-INFINITY, -0.0, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, false, &XlaBuilder::Lt);
+ TestCompare<float>(-0.0, 0.0, false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, true, &XlaBuilder::Lt);
+ TestCompare<float>(0.0, INFINITY, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, false, &XlaBuilder::Ge);
+ TestCompare<float>(-INFINITY, -0.0, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, true, &XlaBuilder::Ge);
+ TestCompare<float>(-0.0, 0.0, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, false, &XlaBuilder::Ge);
+ TestCompare<float>(0.0, INFINITY, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
@@ -813,65 +816,65 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
}
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
- TestMinMax<int32>(10, 3, 3, &XlaBuilder::Min);
+ TestMinMax<int32>(10, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
- TestMinMax<int32>(-100, 3, -100, &XlaBuilder::Min);
+ TestMinMax<int32>(-100, 3, -100, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
- TestMinMax<int32>(10, 3, 10, &XlaBuilder::Max);
+ TestMinMax<int32>(10, 3, 10, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
- TestMinMax<int32>(-100, 3, 3, &XlaBuilder::Max);
+ TestMinMax<int32>(-100, 3, 3, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, 3, &XlaBuilder::Min);
+ TestMinMax<uint32>(large, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
- TestMinMax<uint32>(0, 5, 0, &XlaBuilder::Min);
+ TestMinMax<uint32>(0, 5, 0, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, large, &XlaBuilder::Max);
+ TestMinMax<uint32>(large, 3, large, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
- TestMinMax<uint32>(0, 5, 5, &XlaBuilder::Max);
+ TestMinMax<uint32>(0, 5, 5, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 3.1f, &XlaBuilder::Min);
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min);
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Min);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Min);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Min);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 10.1f, &XlaBuilder::Max);
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max);
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Max);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Max);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Max);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Max);
}
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
@@ -897,18 +900,6 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
ComputeAndCompareR0<int32>(&b, 10, {});
}
-XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
- XlaBuilder builder(TestName());
- Literal zero_literal = Literal::Zero(PrimitiveType::F32);
-
- std::unique_ptr<GlobalData> zero_data =
- client_->TransferToServer(zero_literal).ConsumeValueOrDie();
-
- XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
- SqrtF32(zero);
-
- ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
-}
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
XlaBuilder builder(TestName());