aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/half_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/half_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/half_test.cc94
1 files changed, 40 insertions, 54 deletions
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 76bf47845c..249a4b2493 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase {
static const int kNumElements = 4;
};
-using UnaryBuildFuncTy =
- std::function<void(xla::XlaBuilder*, const xla::XlaOp& src)>;
+using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>;
struct UnaryOpTestParam {
std::function<half(half)> compute_func;
@@ -49,7 +48,8 @@ class UnaryOpTest : public HalfTestBase,
public ::testing::WithParamInterface<UnaryOpTestParam> {};
XLA_TEST_P(UnaryOpTest, Ops) {
- std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1)});
+ std::vector<half> x({half(1.4), half(-2.3), half(3.2), half(-4.1), half(9.0),
+ half(42.0), half(-9.0), half(-100.0)});
XlaBuilder builder(TestName());
XlaOp x_opnd;
auto x_data = CreateR1Parameter<half>(x, /*parameter_number=*/0, "x",
@@ -62,7 +62,7 @@ XLA_TEST_P(UnaryOpTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
}
@@ -79,18 +79,17 @@ half round_imp(half value) {
INSTANTIATE_TEST_CASE_P(
half, UnaryOpTest,
::testing::Values(
- UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs},
- UnaryOpTestParam{[](half x) { return round_imp(x); },
- &XlaBuilder::Round},
- UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil},
- UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos},
- UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp},
- UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor},
- UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log},
- UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg},
- UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign},
- UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin},
- UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh}
+ UnaryOpTestParam{[](half x) { return abs(x); }, &Abs},
+ UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round},
+ UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil},
+ UnaryOpTestParam{[](half x) { return cos(x); }, &Cos},
+ UnaryOpTestParam{[](half x) { return exp(x); }, &Exp},
+ UnaryOpTestParam{[](half x) { return floor(x); }, &Floor},
+ UnaryOpTestParam{[](half x) { return log(x); }, &Log},
+ UnaryOpTestParam{[](half x) { return -x; }, &Neg},
+ UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign},
+ UnaryOpTestParam{[](half x) { return sin(x); }, &Sin},
+ UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh}
));
@@ -118,19 +117,18 @@ XLA_TEST_P(UnaryPredTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
}
INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
::testing::Values(UnaryPredTestParam{
- [](half x) { return isfinite(x); },
- &XlaBuilder::IsFinite}));
+ [](half x) { return isfinite(x); }, &IsFinite}));
-using BinaryBuildFuncTy = std::function<void(
- xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y,
- tensorflow::gtl::ArraySlice<int64>)>;
+using BinaryBuildFuncTy =
+ std::function<void(const xla::XlaOp& x, const xla::XlaOp& y,
+ tensorflow::gtl::ArraySlice<int64>)>;
struct BinaryOpTestParam {
std::function<half(half, half)> compute_func;
@@ -159,7 +157,7 @@ XLA_TEST_P(BinaryOpTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
error_spec_);
@@ -173,22 +171,15 @@ half atan2_imp(half x, half y) {
INSTANTIATE_TEST_CASE_P(
half, BinaryOpTest,
::testing::Values(
- BinaryOpTestParam{[](half x, half y) { return x + y; },
- &XlaBuilder::Add},
+ BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add},
BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); },
- &XlaBuilder::Atan2},
- BinaryOpTestParam{[](half x, half y) { return x / y; },
- &XlaBuilder::Div},
- BinaryOpTestParam{[](half x, half y) { return max(x, y); },
- &XlaBuilder::Max},
- BinaryOpTestParam{[](half x, half y) { return min(x, y); },
- &XlaBuilder::Min},
- BinaryOpTestParam{[](half x, half y) { return x * y; },
- &XlaBuilder::Mul},
- BinaryOpTestParam{[](half x, half y) { return pow(x, y); },
- &XlaBuilder::Pow},
- BinaryOpTestParam{[](half x, half y) { return x - y; },
- &XlaBuilder::Sub}
+ &Atan2},
+ BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div},
+ BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max},
+ BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min},
+ BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul},
+ BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow},
+ BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub}
));
@@ -221,27 +212,22 @@ XLA_TEST_P(BinaryPredTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
}
INSTANTIATE_TEST_CASE_P(
half, BinaryPredTest,
- ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; },
- &XlaBuilder::Eq},
- BinaryPredTestParam{[](half x, half y) { return x != y; },
- &XlaBuilder::Ne},
- BinaryPredTestParam{[](half x, half y) { return x >= y; },
- &XlaBuilder::Ge},
- BinaryPredTestParam{[](half x, half y) { return x > y; },
- &XlaBuilder::Gt},
- BinaryPredTestParam{[](half x, half y) { return x <= y; },
- &XlaBuilder::Le},
- BinaryPredTestParam{[](half x, half y) { return x < y; },
- &XlaBuilder::Lt}
-
- ));
+ ::testing::Values(
+ BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq},
+ BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne},
+ BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge},
+ BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt},
+ BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le},
+ BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt}
+
+ ));
} // namespace
} // namespace xla