diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/convert_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/convert_test.cc | 86 |
1 files changed, 75 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 292942a49e..0fb6853e3f 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <array> #include <cstdint> #include <limits> #include <memory> @@ -52,13 +53,67 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { ComputeAndCompareR1<int32>(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1S32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<int32>(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {42, 64}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1S32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<int32>(&builder, {42, 0, -64}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 64}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {42, 64}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1S32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 64}); + ConvertElementType(a, S32); + + std::vector<int32> expected = {42, 64}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1U32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<uint32>(&builder, {42, 0, 64}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1F32ToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1<float>(&builder, {42.0f, 64.0f}); ConvertElementType(a, F32); std::vector<float> expected = {42.0f, 64.0f}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); +} + +TEST_F(ConvertTest, ConvertR1F32ToR1PRED) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<float>(&builder, {42.0f, 0.0f, 64.0f}); + ConvertElementType(a, PRED); + + std::array<bool, 3> expected = {true, false, true}; + ComputeAndCompareR1<bool>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { @@ -67,7 +122,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { ConvertElementType(a, F32); std::vector<float> expected = {42.0f, 64.0f}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { @@ -79,6 +134,15 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { ComputeAndCompareR1<int32>(&builder, expected, {}); } +TEST_F(ConvertTest, ConvertR1PREDToR1U32) { + XlaBuilder builder(TestName()); + auto a = ConstantR1<bool>(&builder, {true, false, true}); + ConvertElementType(a, U32); + + std::vector<uint32> expected = {1, 0, 1}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + TEST_F(ConvertTest, ConvertR1PREDToR1F32) { XlaBuilder builder(TestName()); auto a = ConstantR1<bool>(&builder, {true, false, true}); @@ -94,7 +158,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { ConvertElementType(a, F32); std::vector<float> expected = {}; - ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1<float>(&builder, expected, {}); } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { @@ -145,7 +209,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { static_cast<int64>(0x8000008000000000LL), static_cast<int64>(0x8000010000000000LL), }; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int64>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -164,7 +228,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -182,7 +246,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XlaBuilder builder(TestName()); std::vector<float> arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -199,7 +263,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XlaBuilder builder(TestName()); std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -216,7 +280,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { XlaBuilder builder(TestName()); std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000}; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int32>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -253,7 +317,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { 9223370937343148032.f, -9223371487098961920.f, -9223370937343148032.f}; - std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg}); + std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg}); auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param"); std::unique_ptr<GlobalData> arg_data = client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); @@ -391,7 +455,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1<half>(input))); + client_->TransferToServer(*LiteralUtil::CreateR1<half>(input))); XlaBuilder builder(TestName()); ConvertElementType( @@ -411,7 +475,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr<GlobalData> dot_lhs_handle, - client_->TransferToServer(*Literal::CreateR1<float>(input))); + client_->TransferToServer(*LiteralUtil::CreateR1<float>(input))); XlaBuilder builder(TestName()); ConvertElementType( |