aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/convert_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/convert_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc86
1 files changed, 75 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 292942a49e..0fb6853e3f 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <array>
#include <cstdint>
#include <limits>
#include <memory>
@@ -52,13 +53,67 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}
+TEST_F(ConvertTest, ConvertR1S32ToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {42, 64};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1S32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {42, 0, -64});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 64});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {42, 64};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1S32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 64});
+ ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1U32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {42, 0, 64});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
XlaBuilder builder(TestName());
auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1PRED) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f});
+ ConvertElementType(a, PRED);
+
+ std::array<bool, 3> expected = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
@@ -67,7 +122,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
@@ -79,6 +134,15 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
ComputeAndCompareR1<int32>(&builder, expected, {});
}
+TEST_F(ConvertTest, ConvertR1PREDToR1U32) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<bool>(&builder, {true, false, true});
+ ConvertElementType(a, U32);
+
+ std::vector<uint32> expected = {1, 0, 1};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
XlaBuilder builder(TestName());
auto a = ConstantR1<bool>(&builder, {true, false, true});
@@ -94,7 +158,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
ConvertElementType(a, F32);
std::vector<float> expected = {};
- ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareR1<float>(&builder, expected, {});
}
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
@@ -145,7 +209,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int64>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -164,7 +228,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -182,7 +246,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -199,7 +263,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -216,7 +280,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int32>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -253,7 +317,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
@@ -391,7 +455,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<half>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
ConvertElementType(
@@ -411,7 +475,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<float>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
ConvertElementType(