/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include #include #include #include "absl/types/span.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { class ArrayElementwiseOpTest : public ClientLibraryTestBase { public: ErrorSpec error_spec_{0.0001, 0.0001}; }; class ArrayElementwiseOpTestParamCount : public ArrayElementwiseOpTest, public ::testing::WithParamInterface {}; XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); Neg(a); 
ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); Neg(a); ComputeAndCompareR1(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-1, 0, 1, 324, std::numeric_limits::min(), std::numeric_limits::max()}); Neg(a); // -min == min for int32 due to an overflow. In C++ it is undefined behavior // to do this calculation. For XLA we have not specified that, so it // ought to work. ComputeAndCompareR1(&builder, {1, 0, -1, -324, std::numeric_limits::min(), -std::numeric_limits::max()}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); Neg(a); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); Neg(a); ComputeAndCompareR1( &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, { -1, 1, 0, 0x12345678, static_cast(0xffffffff12345678l), static_cast(0x8000000000000000LL), static_cast(0x8000000000000001LL), }); Neg(a); LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); ComputeAndCompareR1(&builder, { 1, -1, 0, -0x12345678, 0xedcba988, static_cast(0x8000000000000000LL), -static_cast(0x8000000000000001LL), }, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); IsFinite(a); ComputeAndCompareR1(&builder, {}, {}); } // A non-canonical quiet NaN value. 
static const float kNonCanonicalNaN = tensorflow::bit_cast(0x7FD01234); XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) { XlaBuilder builder(TestName()); IsFinite(ConstantR0(&builder, NAN)); ComputeAndCompareR0(&builder, false, {}); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); IsFinite(ConstantR0(&builder, kNonCanonicalNaN)); ComputeAndCompareR0(&builder, false, {}); const float inf = std::numeric_limits::infinity(); IsFinite(ConstantR0(&builder, inf)); ComputeAndCompareR0(&builder, false, {}); IsFinite(ConstantR0(&builder, -inf)); ComputeAndCompareR0(&builder, false, {}); IsFinite(ConstantR0(&builder, 0.0f)); ComputeAndCompareR0(&builder, true, {}); } XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) { XlaBuilder builder(TestName()); const float inf = std::numeric_limits::infinity(); EXPECT_TRUE(std::isnan(kNonCanonicalNaN)); auto a = ConstantR1(&builder, {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}}); IsFinite(a); ComputeAndCompareR1(&builder, {false, true, false, true, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f}); auto b = ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f}); Add(a, b); ComputeAndCompareR1(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Add(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); auto b = ConstantR1( &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); Add(a, b); ComputeAndCompareR1( &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, error_spec_); } 
// Add of two empty rank-1 complex constants is expected to be empty.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(&builder, {});
  auto b = ConstantR1(&builder, {});
  Add(a, b);

  ComputeAndCompareR1(&builder, {}, {}, error_spec_);
}

// Adds two 64-bit operands supplied as parameters. The expected values are
// computed on the host with the same element-by-element `+`, so the inputs
// can include pairs whose sum wraps.
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
  XlaBuilder b(TestName());

  // NOTE(review): 0x7FFFFFFFFFFFFFFLL (15 F digits) appears next to
  // 0x7FFFFFFFFFFFFFFFLL (16 F digits) in both operands -- confirm whether
  // the shorter constant is intentional.
  std::vector lhs{0xFFFFFFFF,
                  static_cast(-1),
                  0,
                  0,
                  0x7FFFFFFFFFFFFFFFLL,
                  0x7FFFFFFFFFFFFFFLL,
                  0x8000000000000000LL,
                  0x8000000000000000LL,
                  1};
  Literal lhs_literal = LiteralUtil::CreateR1({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector rhs{1,
                  0x7FFFFFFFFFFFFFFLL,
                  0x7FFFFFFFFFFFFFFFLL,
                  0x8000000000000000LL,
                  0,
                  static_cast(-1),
                  0,
                  1,
                  0x8000000000000000LL};
  Literal rhs_literal = LiteralUtil::CreateR1({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Add(lhs_param, rhs_param);

  // Host-side reference result; index type matches vector::size().
  std::vector expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] + rhs[i];
  }

  ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()});
}

// Subtracts two 64-bit operands supplied as parameters; the expected values
// are computed on the host with the same element-by-element `-`.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
  XlaBuilder b(TestName());

  std::vector lhs{static_cast(0x8000000000000000LL),
                  static_cast(0x8000000000000000LL),
                  -1,
                  0x7FFFFFFFFFFFFFFLL,
                  0x7FFFFFFFFFFFFFFFLL,
                  1,
                  0,
                  -1};
  Literal lhs_literal = LiteralUtil::CreateR1({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");
  std::unique_ptr lhs_data =
      client_->TransferToServer(lhs_literal).ConsumeValueOrDie();

  std::vector rhs{-1,
                  0,
                  static_cast(0x8000000000000000LL),
                  1,
                  0,
                  0x7FFFFFFFFFFFFFFLL,
                  0x7FFFFFFFFFFFFFFFLL,
                  0x7FFFFFFFFFFFFFFFLL};
  Literal rhs_literal = LiteralUtil::CreateR1({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");
  std::unique_ptr rhs_data =
      client_->TransferToServer(rhs_literal).ConsumeValueOrDie();

  Sub(lhs_param, rhs_param);

  // Host-side reference result; index type matches vector::size().
  std::vector expected(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i) {
    expected[i] = lhs[i] - rhs[i];
  }

  ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()});
}

// Lt on a single pair of 64-bit values that differ in the top bit
// (0x8000000000000000 vs 0x7FFFFFFFFFFFFFFF). No expected value is spelled
// out here; the literals are handed to ComputeAndCompare directly.
XLA_TEST_F(ArrayElementwiseOpTest, CmpTwoConstantU64s) {
  XlaBuilder b(TestName());

  std::vector lhs{static_cast(0x8000000000000000ULL)};
  Literal lhs_literal = LiteralUtil::CreateR1({lhs});
  auto lhs_param = Parameter(&b, 0, lhs_literal.shape(), "lhs_param");

  std::vector rhs{static_cast(0x7FFFFFFFFFFFFFFFULL)};
  Literal rhs_literal = LiteralUtil::CreateR1({rhs});
  auto rhs_param = Parameter(&b, 1, rhs_literal.shape(), "rhs_param");

  Lt(lhs_param, rhs_param);

  ComputeAndCompare(&b, {std::move(lhs_literal), std::move(rhs_literal)});
}

// Adds `count` (the test parameter) element arrays in all four
// constant/parameter combinations and sums the four results, so every
// expected element is 4 * (a + b).
TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
  const int count = GetParam();
  XlaBuilder builder(TestName());
  std::vector a_values;
  std::vector b_values;
  for (int i = 0; i < count; ++i) {
    a_values.push_back(i / static_cast(count));
    b_values.push_back(2 * i / static_cast(count + 2));
  }

  // `a` is made available both as a baked-in constant and as parameter 0.
  Literal a_literal = LiteralUtil::CreateR1({a_values});
  std::unique_ptr a_data =
      client_->TransferToServer(a_literal).ConsumeValueOrDie();
  auto a_constant = ConstantR1(&builder, a_values);
  auto a_param = Parameter(&builder, 0, a_literal.shape(), "a_param");

  // Likewise for `b`: parameter 1 (shaped from its own literal) and a
  // constant. The ops are created in the same order as before; only the local
  // names now match what each op actually is.
  Literal b_literal = LiteralUtil::CreateR1({b_values});
  std::unique_ptr b_data =
      client_->TransferToServer(b_literal).ConsumeValueOrDie();
  auto b_param = Parameter(&builder, 1, b_literal.shape(), "b_param");
  auto b_constant = ConstantR1(&builder, b_values);

  auto sum1 = Add(a_constant, b_param);
  auto sum2 = Add(a_constant, b_constant);
  auto sum3 = Add(a_param, b_param);
  auto sum4 = Add(a_param, b_constant);

  auto sum = Add(sum1, sum2);
  sum = Add(sum, sum3);
  sum = Add(sum, sum4);

  std::vector expected;
  for (int64 i = 0; i < count; ++i) {
    expected.push_back(4 * (a_values[i] + b_values[i]));
  }

  ComputeAndCompareR1(&builder, expected, {a_data.get(), b_data.get()},
                      error_spec_);
}
// Sub: elementwise difference of two rank-1 float constants.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
  auto subtrahend =
      ConstantR1(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f}, {},
                      error_spec_);
}

// Sub: two empty rank-1 operands give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1(&builder, {});
  auto subtrahend = ConstantR1(&builder, {});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(&builder, {}, {}, error_spec_);
}

// Sub: elementwise difference of two rank-1 integer constants (exact match,
// no error spec).
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1(&builder, {-1, 0, 2, 1000000000});
  auto subtrahend = ConstantR1(&builder, {-1, 2, 1, -1});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(&builder, {0, -2, 1, 1000000001}, {});
}

// Sub: two empty rank-1 integer operands give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1(&builder, {});
  auto subtrahend = ConstantR1(&builder, {});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(&builder, {}, {});
}

// Sub: elementwise difference of two rank-1 complex constants, each element
// written as a {re, im} pair.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
  XlaBuilder builder(TestName());
  auto minuend =
      ConstantR1(&builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
  auto subtrahend = ConstantR1(
      &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(
      &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {},
      error_spec_);
}

// Sub: two empty rank-1 complex operands give an empty result.
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto minuend = ConstantR1(&builder, {});
  auto subtrahend = ConstantR1(&builder, {});
  Sub(minuend, subtrahend);

  ComputeAndCompareR1(&builder, {}, {}, error_spec_);
}

// Div: elementwise quotient of two rank-1 float constants.
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
  XlaBuilder builder(TestName());
  auto numerator = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto denominator = ConstantR1(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
  Div(numerator, denominator);

  ComputeAndCompareR1(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
                      error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } class IntegerDivideOpTest : public ArrayElementwiseOpTest { protected: template void TestDivRem(absl::Span dividends, absl::Span divisors, absl::Span quotients, absl::Span remainders) { { XlaBuilder builder(TestName()); XlaOp dividend; XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); Div(dividend, divisor); ComputeAndCompareR1(&builder, quotients, {dividend_data.get(), divisor_data.get()}); } // Test with a compile-time constant divisor. { XlaBuilder builder(TestName()); XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); Div(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, quotients, {dividend_data.get()}); } { XlaBuilder builder(TestName()); XlaOp dividend; XlaOp divisor; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); auto divisor_data = CreateR1Parameter(divisors, 1, "divisor", &builder, &divisor); Rem(dividend, divisor); ComputeAndCompareR1(&builder, remainders, {dividend_data.get(), divisor_data.get()}); } // Test with a compile-time constant divisor. { XlaBuilder builder(TestName()); XlaOp dividend; auto dividend_data = CreateR1Parameter(dividends, 0, "dividend", &builder, ÷nd); Rem(dividend, ConstantR1(&builder, divisors)); ComputeAndCompareR1(&builder, remainders, {dividend_data.get()}); } } }; XLA_TEST_F(IntegerDivideOpTest, DivS32s) { // clang-format off // Some interesting values to test. 
std::vector vals = { INT32_MIN, INT32_MIN + 1, INT32_MIN + 2, -0x40000000, -0x3fffffff, -271181, -1309, -17, -10, -5, -3, -2, -1, 0, 1, 2, 3, 5, 10, 17, 26, 101, 7919, 0x40000000, INT32_MAX - 2, INT32_MAX - 1, INT32_MAX}; // clang-format on std::vector dividends, divisors, quotients, remainders; for (int32 divisor : vals) { if (divisor != 0) { for (int32 dividend : vals) { // Avoid integer overflow. if (dividend != INT32_MIN || divisor != -1) { dividends.push_back(dividend); divisors.push_back(divisor); quotients.push_back(dividend / divisor); remainders.push_back(dividend % divisor); } } } } TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { std::vector dividends = {5, INT32_MIN}, divisors = {0, -1}, quotients = {-1, INT32_MIN}, remainders = {5, 0}; TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(IntegerDivideOpTest, DivU32s) { // clang-format off // Some interesting values to test. std::vector vals = { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0xABCDEF12, 0xCAFEBEEF, 0x80000000, 0x80000001, UINT32_MAX - 2, UINT32_MAX - 1, UINT32_MAX}; // clang-format on std::vector dividends, divisors, quotients, remainders; for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { dividends.push_back(dividend); divisors.push_back(divisor); quotients.push_back(dividend / divisor); remainders.push_back(dividend % divisor); } } } TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { std::vector dividends = {5}, divisors = {0}, quotients = {-1}, remainders = {5}; TestDivRem(dividends, divisors, quotients, remainders); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); auto b = ConstantR1(&builder, {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); Div(a, b); ComputeAndCompareR1( &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, 
{1.0f, 0.0f}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Div(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f}); auto b = ConstantR1( &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f}); Rem(a, b); ComputeAndCompareR1( &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Rem(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0}); auto b = ConstantR1( &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0}); Rem(a, b); ComputeAndCompareR1( &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto b = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); Mul(a, b); ComputeAndCompareR1(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) { std::vector data = {0, 1, -1, 1234, 0x1a243514, std::numeric_limits::max(), std::numeric_limits::min()}; // Form the test data set using 
all products of 'data' with itself. std::vector a_data, b_data, expected; for (int32 a : data) { for (int32 b : data) { a_data.push_back(a); b_data.push_back(b); expected.push_back(static_cast(a) * static_cast(b)); } } XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, a_data); auto b = ConstantR1(&builder, b_data); Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Mul(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { std::vector data = {0, 1, 0xDEADBEEF, 1234, 0x1a243514, 0xFFFFFFFF, 0x80808080}; // Form the test data set using all products of 'data' with itself. std::vector a_data, b_data, expected; for (uint32 a : data) { for (uint32 b : data) { a_data.push_back(a); b_data.push_back(b); expected.push_back(a * b); } } XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, a_data); auto b = ConstantR1(&builder, b_data); Mul(a, b); ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); auto b = ConstantR1(&builder, {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); Mul(a, b); ComputeAndCompareR1( &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Mul(a, b); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {false, false, true, true}); auto b = ConstantR1(&builder, {false, true, false, true}); And(a, b); ComputeAndCompareR1(&builder, {false, false, 
false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{false, false}, {true, true}}); auto b = ConstantR2(&builder, {{false, true}, {false, true}}); And(a, b); Array2D expected_array({{false, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, -1, -8}); auto b = ConstantR1(&builder, {5, -7, 12}); And(a, b); ComputeAndCompareR1(&builder, {0, -7, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, -5}, {-1, 5}}); auto b = ConstantR2(&builder, {{1, -6}, {4, 5}}); And(a, b); Array2D expected_array({{0, -6}, {4, 5}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); And(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, 1, 8}); auto b = ConstantR1(&builder, {5, 7, 12}); And(a, b); ComputeAndCompareR1(&builder, {0, 1, 8}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, 1}, {3, 8}}); auto b = ConstantR2(&builder, {{1, 0}, {7, 6}}); And(a, b); Array2D expected_array({{0, 0}, {3, 0}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); And(a, b); ComputeAndCompareR1(&builder, {}, 
{}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {false, false, true, true}); auto b = ConstantR1(&builder, {false, true, false, true}); Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{false, false}, {true, true}}); auto b = ConstantR2(&builder, {{false, true}, {false, true}}); Or(a, b); Array2D expected_array({{false, true}, {true, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, -1, 8}); auto b = ConstantR1(&builder, {5, -7, 4}); Or(a, b); ComputeAndCompareR1(&builder, {5, -1, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); Or(a, b); Array2D expected_array({{5, -1}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, 1, 8}); auto b = ConstantR1(&builder, {5, 7, 4}); Or(a, b); ComputeAndCompareR1(&builder, {5, 7, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); Or(a, b); Array2D expected_array({{5, 7}, {12, 9}}); 
ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {false, false, true, true}); auto b = ConstantR1(&builder, {false, true, false, true}); Xor(a, b); ComputeAndCompareR1(&builder, {false, true, true, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{false, false}, {true, true}}); auto b = ConstantR2(&builder, {{false, true}, {false, true}}); Xor(a, b); Array2D expected_array({{false, true}, {true, false}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Xor(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, -1, 8}); auto b = ConstantR1(&builder, {5, -7, 4}); Xor(a, b); ComputeAndCompareR1(&builder, {5, 6, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, -1}, {8, 8}}); auto b = ConstantR2(&builder, {{5, -7}, {4, 1}}); Xor(a, b); Array2D expected_array({{5, 6}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Xor(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, 1, 8}); auto b = ConstantR1(&builder, {5, 7, 4}); Xor(a, b); 
ComputeAndCompareR1(&builder, {5, 6, 12}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{0, 1}, {8, 8}}); auto b = ConstantR2(&builder, {{5, 7}, {4, 1}}); Xor(a, b); Array2D expected_array({{5, 6}, {12, 9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); auto b = ConstantR1(&builder, {}); Xor(a, b); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {false, true, true, false}); Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{false, true}, {true, false}}); Not(a); Array2D expected_array({{true, false}, {false, true}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-1, 0, 1}); Not(a); ComputeAndCompareR1(&builder, {0, -1, -2}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { XlaBuilder builder(TestName()); auto a = ConstantR2(&builder, {{-1, 0}, {1, 8}}); Not(a); Array2D expected_array({{0, -1}, {-2, -9}}); ComputeAndCompareR2(&builder, expected_array, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {}); Not(a); ComputeAndCompareR1(&builder, {}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, 4294967295}); Not(a); ComputeAndCompareR1(&builder, {4294967295, 0}, {}); } 
// Not: bitwise complement on a 2x2 array; 0 <-> 4294967295 and
// 1 <-> 4294967294.
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
  XlaBuilder builder(TestName());
  auto a = ConstantR2(&builder, {{0, 4294967295}, {1, 4294967294}});
  Not(a);

  Array2D expected_array({{4294967295, 0}, {4294967294, 1}});
  ComputeAndCompareR2(&builder, expected_array, {});
}

// Not of an empty rank-1 operand is expected to be empty.
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(&builder, {});
  Not(a);

  ComputeAndCompareR1(&builder, {}, {});
}

// ShiftLeft on signed values. The last three shift amounts (32, 100, -1) are
// outside 0..31 and the expected result for each is 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(
      &builder, {static_cast(0x12345678), static_cast(0xF0001000), 1, 3, 77,
                 1, -3, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
  ShiftLeft(a, b);

  ComputeAndCompareR1(&builder,
                      {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136,
                       0, 0, 0},
                      {});
}

// ShiftRightArithmetic on signed values. For the out-of-range shift amounts
// (32, 100, -1) the expected result is 0 for the non-negative inputs and -1
// for the negative input (-3).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(
      &builder, {static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77,
                 1, -3, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
  ShiftRightArithmetic(a, b);

  ComputeAndCompareR1(
      &builder,
      {static_cast(0xF9234567), static_cast(0x00100010), 0, 0, 19, 0, -1, 0},
      {});
}

// ShiftRightLogical on signed values. The first expected value has a zero top
// nibble (no sign fill), and the out-of-range shift amounts all expect 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(
      &builder, {static_cast(0x92345678), static_cast(0x10001000), 1, 3, 77,
                 1, -3, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
  ShiftRightLogical(a, b);

  ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}

// ShiftLeft on unsigned values; shift amounts 32, 100 and ~0u all expect 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(&builder,
                      {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
  ShiftLeft(a, b);

  ComputeAndCompareR1(
      &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {});
}

// ShiftRightArithmetic on unsigned operands. The expected values show the top
// bit being replicated (0x92345678 >> 4 -> 0xF9234567, and ~3u shifted by 100
// -> ~0u).
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(&builder,
                      {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
  ShiftRightArithmetic(a, b);

  ComputeAndCompareR1(
      &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {});
}

// ShiftRightLogical on unsigned values; shift amounts 32, 100 and ~0u all
// expect 0.
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1(&builder,
                      {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
  auto b = ConstantR1(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
  ShiftRightLogical(a, b);

  ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
}

// Eq on floats. Fast-math is disabled because the inputs contain NaN; any
// pair involving NaN is expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
  Eq(lhs, rhs);

  ComputeAndCompareR1(&builder, {false, false, true, false, false}, {});
}

// Eq of two empty rank-1 operands is expected to be empty.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {});
  auto rhs = ConstantR1(&builder, {});
  Eq(lhs, rhs);

  ComputeAndCompareR1(&builder, {}, {});
}

// Ge on floats; pairs involving NaN are expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Ge(lhs, rhs);

  ComputeAndCompareR1(&builder, {false, true, true, false, false}, {});
}

// Gt on floats; pairs involving NaN are expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Gt(lhs, rhs);

  ComputeAndCompareR1(&builder, {false, true, true, false, false}, {});
}

// Le on floats, including an equal pair (5.0f, 5.0f); pairs involving NaN are
// expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Le(lhs, rhs);

  ComputeAndCompareR1(&builder, {true, true, false, false, false}, {});
}

// Lt on floats; pairs involving NaN are expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
  Lt(lhs, rhs);

  ComputeAndCompareR1(&builder, {true, false, false, false, false}, {});
}

// Eq on signed integers: each of {min, 0, max} on the left is paired with a
// smaller, an equal and a larger value on the right.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Eq(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}

// Eq of two empty rank-1 integer operands is expected to be empty.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {});
  auto rhs = ConstantR1(&builder, {});
  Eq(lhs, rhs);

  ComputeAndCompareR1(&builder, {}, {});
}

// Eq on complex values written as {re, im} pairs. Only the pair that matches
// in both components is expected to be equal; pairs with a NaN component are
// expected to compare false.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
  SetFastMathDisabled(true);
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f},
                                   {1.0f, 25.5f},
                                   {2.25f, -3.0f},
                                   {NAN, 0.0f},
                                   {1.0f, 6.0f}});
  auto rhs = ConstantR1(&builder, {{0.0f, 10.0f},
                                   {1.0f, 5.0f},
                                   {2.25f, -3.0f},
                                   {10.0f, 0.0f},
                                   {1.0f, NAN}});
  Eq(lhs, rhs);

  ComputeAndCompareR1(&builder, {false, false, true, false, false}, {});
}

// Eq of two empty rank-1 complex operands is expected to be empty.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {});
  auto rhs = ConstantR1(&builder, {});
  Eq(lhs, rhs);

  ComputeAndCompareR1(&builder, {}, {});
}

// Ne on the same complex inputs as CompareEqC64s; the expected values are the
// element-wise opposite of that test's.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
  // Disable fast-math because we're operating on NaNs.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {{-2.5f, 10.0f},
                                   {1.0f, 25.5f},
                                   {2.25f, -3.0f},
                                   {NAN, 0.0f},
                                   {1.0f, 6.0f}});
  auto rhs = ConstantR1(&builder, {{0.0f, 10.0f},
                                   {1.0f, 5.0f},
                                   {2.25f, -3.0f},
                                   {10.0f, 0.0f},
                                   {1.0f, NAN}});
  Ne(lhs, rhs);

  ComputeAndCompareR1(&builder, {true, true, false, true, true}, {});
}

// Ne on floats; pairs involving NaN are expected to compare true.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
  // Disable fast-math because we're operating on NaNs.
  SetFastMathDisabled(true);

  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
  auto rhs = ConstantR1(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
  Ne(lhs, rhs);

  ComputeAndCompareR1(&builder, {true, false, true, true, true}, {});
}

// Ne on the {min, 0, max} signed-integer grid used by CompareEqS32s.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Ne(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}

// Ge on the {min, 0, max} signed-integer grid.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Ge(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {true, false, false, true, true, false, true, true, true}, {});
}

// Gt on the {min, 0, max} signed-integer grid.
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Gt(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {false, false, false, true, false, false, true, true, false},
      {});
}

// Le on the {min, 0, max} signed-integer grid.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Le(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {true, true, true, false, true, true, false, false, true}, {});
}

// Lt on the {min, 0, max} signed-integer grid.
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
  const int32 min = std::numeric_limits::min();
  const int32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs =
      ConstantR1(&builder, {min, min, min, 0, 0, 0, max, max, max});
  auto rhs = ConstantR1(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
  Lt(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {false, true, true, false, false, true, false, false, false},
      {});
}

// Eq on unsigned integers: each of {0, 5, max} on the left is paired with
// three values on the right, one of which is equal.
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
  const uint32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
  auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
  Eq(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {true, false, false, false, true, false, false, false, true},
      {});
}

// Ne on the {0, 5, max} unsigned-integer grid.
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
  const uint32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
  auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
  Ne(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {false, true, true, true, false, true, true, true, false}, {});
}

// Ge on the {0, 5, max} unsigned-integer grid. (This test continues past the
// end of the visible chunk; its tokens are carried through unchanged.)
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
  const uint32 max = std::numeric_limits::max();
  XlaBuilder builder(TestName());
  auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
  auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
  Ge(lhs, rhs);

  ComputeAndCompareR1(
      &builder, {true, false, false,
true, true, false, true, true, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); Gt(lhs, rhs); ComputeAndCompareR1( &builder, {false, false, false, true, false, false, true, true, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); Le(lhs, rhs); ComputeAndCompareR1( &builder, {true, true, true, false, true, true, false, false, true}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {0, 0, 0, 5, 5, 5, max, max, max}); auto rhs = ConstantR1(&builder, {0, 1, max, 4, 5, 6, 0, 1, max}); Lt(lhs, rhs); ComputeAndCompareR1( &builder, {false, true, true, false, false, true, false, false, false}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f}); auto rhs = ConstantR1(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f}); Pow(lhs, rhs); ComputeAndCompareR1( &builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {-2.0f, -0.6f, -0.6f, 0.0f}); auto rhs = ConstantR1(&builder, {0.5f, 0.6f, -0.6f, -0.6f}); Pow(lhs, rhs); ComputeAndCompareR1(&builder, {NAN, NAN, NAN, INFINITY}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = 
ConstantR1(&builder, {}); auto rhs = ConstantR1(&builder, {}); Pow(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } // Some Pow cases that can be implemented more efficiently. XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { XlaBuilder b(TestName()); std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; Literal param_literal = LiteralUtil::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(param_literal).ConsumeValueOrDie(); auto sum = ConstantR0(&b, 0.0f); auto param = Parameter(&b, 0, param_literal.shape(), "param"); for (float exponent : exponents) { sum = Add(sum, Pow(param, ConstantR0(&b, exponent))); } std::vector expected; for (auto value : values) { float sum = 0.0f; for (float exponent : exponents) { sum += std::pow(value, exponent); } expected.push_back(sum); } ComputeAndCompareR1(&b, expected, {param_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Pow(Exp(param0), param1); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = std::pow(std::exp(values0[i]), values1[i]); } ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; Literal literal0 = 
LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Log(Pow(param0, param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = std::log(std::pow(values0[i], values1[i])); } ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); Mul(Exp(param0), Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = std::exp(values0[i]) * std::exp(values1[i]); } ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 
1, literal1.shape(), "param1"); Div(param0, Exp(param1)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = values0[i] / std::exp(values1[i]); } ComputeAndCompareR1(&b, expected, {data0.get(), data1.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(literal2).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(Div(param0, param1), param2); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = (values0[i] / values1[i]) / values2[i]; } ComputeAndCompareR1( &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = 
client_->TransferToServer(literal2).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Div(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = values0[i] / (values1[i] / values2[i]); } ComputeAndCompareR1( &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(literal2).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); Div(param0, Pow(param1, param2)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = values0[i] / std::pow(values1[i], values2[i]); } ComputeAndCompareR1( &b, expected, {data0.get(), data1.get(), data2.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; std::vector values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f}; std::vector values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f}; Literal literal0 = LiteralUtil::CreateR1(values0); 
std::unique_ptr data0 = client_->TransferToServer(literal0).ConsumeValueOrDie(); Literal literal1 = LiteralUtil::CreateR1(values1); std::unique_ptr data1 = client_->TransferToServer(literal1).ConsumeValueOrDie(); Literal literal2 = LiteralUtil::CreateR1(values2); std::unique_ptr data2 = client_->TransferToServer(literal2).ConsumeValueOrDie(); Literal literal3 = LiteralUtil::CreateR1(values3); std::unique_ptr data3 = client_->TransferToServer(literal3).ConsumeValueOrDie(); auto param0 = Parameter(&b, 0, literal0.shape(), "param0"); auto param1 = Parameter(&b, 1, literal1.shape(), "param1"); auto param2 = Parameter(&b, 2, literal2.shape(), "param2"); auto param3 = Parameter(&b, 3, literal3.shape(), "param2"); Div(Div(param0, param1), Div(param2, param3)); std::vector expected(values0.size()); for (int64 i = 0; i < values0.size(); ++i) { expected[i] = (values0[i] / values1[i]) / (values2[i] / values3[i]); } ComputeAndCompareR1( &b, expected, {data0.get(), data1.get(), data2.get(), data3.get()}, error_spec_); } TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { const int count = GetParam(); XlaBuilder builder(TestName()); std::vector values; values.reserve(count); for (int i = 0; i < count; ++i) { values.push_back(i / static_cast(count)); } auto x = ConstantR1(&builder, values); Pow(x, ConstantR0(&builder, 2.0f)); std::vector expected; expected.reserve(values.size()); for (float value : values) { expected.push_back(value * value); } ComputeAndCompareR1(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) { XlaBuilder builder(TestName()); Array4D values(2, 2, 2, 2); std::vector values_vector; std::vector expected_vector; for (int i = 0; i < values.num_elements(); ++i) { values_vector.push_back(static_cast(i) / values.num_elements()); expected_vector.push_back(values_vector.back() * values_vector.back()); } values.SetValues(values_vector); Array4D expected(2, 2, 2, 2, expected_vector); auto x = ConstantR4FromArray4D(&builder, 
values); Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { XlaBuilder builder(TestName()); Array4D values(2, 2, 0, 2); Array4D expected(2, 2, 0, 2); auto x = ConstantR4FromArray4D(&builder, values); Pow(x, ConstantR0(&builder, 2.0f)); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); auto rhs = ConstantR1(&builder, {}); Min(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = ConstantR1(&builder, {1.0, 1.0, 2.25, NAN, 6.0}); auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); Min(lhs, rhs); ComputeAndCompareR1(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = ConstantR1(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f}); auto rhs = ConstantR1(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN}); Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); auto rhs = ConstantR1(&builder, {}); Max(lhs, rhs); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { XlaBuilder builder(TestName()); SetFastMathDisabled(true); auto lhs = ConstantR1(&builder, {1.0, 1.0, 
2.25, NAN, 6.0}); auto rhs = ConstantR1(&builder, {2.0, -5.0, 1.0, 10.0, NAN}); Max(lhs, rhs); ComputeAndCompareR1(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = ConstantR1( &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); Max(x, y); std::vector expected = {min, max, 0, -1, 0, 0, 0, 1, 1, 10, max, max, max}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) { const int32 min = std::numeric_limits::min(); const int32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); auto y = ConstantR1( &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); Min(x, y); std::vector expected = {min, min, min, -10, -1, -1, 0, 0, 0, 1, 0, max, min}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); Max(x, y); std::vector expected = {0, 1, 1, 1, 10, max, max, max}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) { const uint32 max = std::numeric_limits::max(); XlaBuilder builder(TestName()); auto x = ConstantR1(&builder, {0, 0, 1, 1, 1, max, max, max}); auto y = ConstantR1(&builder, {0, 1, 0, 1, 10, 0, 234234, max}); Min(x, y); std::vector expected = {0, 0, 0, 1, 1, 0, 234234, max}; ComputeAndCompareR1(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { XlaBuilder builder(TestName()); auto x = ConstantR1( &builder, {-0.0, 1.0, 2.0, -3.0, 
-4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = ConstantR1( &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); Max(x, y); std::vector expected = {-0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; ComputeAndCompareR1(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { XlaBuilder builder(TestName()); auto u = ConstantR1(&builder, {3.5}); auto v = ConstantR1(&builder, {}); Max(u, v); ComputeAndCompareR1(&builder, {}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { for (int broadcast_dim : {0, 1}) { XlaBuilder builder(TestName()); auto u = ConstantR1(&builder, {3.5}); auto v = ConstantR2FromArray2D(&builder, Array2D(0, 2)); Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); } } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { XlaBuilder builder(TestName()); auto v = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); auto m = ConstantR2(&builder, {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { XlaBuilder builder(TestName()); auto v = ConstantR1(&builder, {}); auto m = ConstantR2(&builder, {{}, {}}); Max(v, m, /*broadcast_dimensions=*/{1}); Array2D expected({{}, {}}); ComputeAndCompareR2(&builder, expected, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { XlaBuilder builder(TestName()); auto scalar = ConstantR0(&builder, 2); Array3D a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); auto array = ConstantR3FromArray3D(&builder, a_3d); Max(array, scalar, /*broadcast_dimensions=*/{}); Array3D expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); ComputeAndCompareR3(&builder, expected, {}); } XLA_TEST_F(ArrayElementwiseOpTest, 
Max3DAndScalarZeroElementS32s) {
  // Max of an empty 2x0x3 S32 array with a scalar keeps the empty shape.
  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<int32>(&builder, 2);
  Array3D<int32> a_3d(2, 0, 3);
  auto array = ConstantR3FromArray3D<int32>(&builder, a_3d);
  Max(array, scalar, /*broadcast_dimensions=*/{});

  Array3D<int32> expected(2, 0, 3);
  ComputeAndCompareR3<int32>(&builder, expected, {});
}

// Min of a 2x3 matrix with a 2-element vector broadcast along dimension 0,
// i.e. each row is compared against one vector element.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder,
                             {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
  auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Same broadcasted Min with a matrix of two empty rows.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto m = ConstantR2<float>(&builder, {{}, {}});
  auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
  Min(m, v, /*broadcast_dimensions=*/{0});

  Array2D<float> expected({{}, {}});
  ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}

// Min of a 2x3 matrix broadcast onto dimensions {1, 3} of a 2x2x1x3 array.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
  XlaBuilder builder(TestName());
  auto array2d = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  auto array4d = ConstantR4FromArray4D<float>(
      &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
                 {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
  Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(
      {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
       {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

// Same broadcasted Min against an empty 2x2x0x3 array.
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
  XlaBuilder builder(TestName());
  auto array2d = ConstantR2<float>(
      &builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
  Array4D<float> arg(2, 2, 0, 3);
  auto array4d = ConstantR4FromArray4D<float>(&builder, arg);
  Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});

  Array4D<float> expected(2, 2, 0, 3);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest,
MinTenS32s) {
  // Min of an ascending and a descending ten-element S32 vector.
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Min(x, y);

  std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Max of an ascending and a descending ten-element S32 vector.
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
  XlaBuilder builder(TestName());
  auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
  auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
  Max(x, y);

  std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
  ComputeAndCompareR1<int32>(&builder, expected, {});
}

// Rem on S32 constants; the expected remainders carry the sign of the
// dividend (-3 rem 10 == -3, 1 rem -10 == 1).
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
  auto b = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
  Rem(a, b);

  ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
}

// Element-wise Clamp(min, arg, max) on F32 vectors without NaNs.
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
  XlaBuilder builder(TestName());
  auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto argument =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
  auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
                             error_spec_);
}

// Clamp of an F32 vector between scalar bounds 0 and 5.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
  XlaBuilder builder(TestName());
  auto minimum = ConstantR0<float>(&builder, 0.0f);
  auto argument = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto maximum = ConstantR0<float>(&builder, 5.0f);
  Clamp(minimum, argument, maximum);

  ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
                             error_spec_);
}

// Sums the four scalar/vector bound combinations of Clamp on F32 operands.
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
  XlaBuilder builder(TestName());
  auto min_scalar = ConstantR0<float>(&builder, 0.0f);
  auto min_vector =
      ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
  auto arg_vector =
      ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
  auto max_scalar = ConstantR0<float>(&builder, 3.0f);
  auto max_vector =
ConstantR1(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. Add(Add(Clamp(min_vector, arg_vector, max_scalar), Clamp(min_scalar, arg_vector, max_vector)), Add(Clamp(min_vector, arg_vector, max_vector), Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { XlaBuilder builder(TestName()); auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0, -5}); auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4, 10}); auto max_vector = ConstantR1(&builder, {3, 0, 25, 5, 123, -1}); Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { XlaBuilder builder(TestName()); auto min_scalar = ConstantR0(&builder, 0); auto min_vector = ConstantR1(&builder, {1, -6, 1, 2, 0}); auto arg_vector = ConstantR1(&builder, {2, 10, -5, 1, 4}); auto max_scalar = ConstantR0(&builder, 3); auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. 
Add(Add(Clamp(min_vector, arg_vector, max_scalar), Clamp(min_scalar, arg_vector, max_vector)), Add(Clamp(min_vector, arg_vector, max_vector), Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { XlaBuilder builder(TestName()); auto min_vector = ConstantR1(&builder, {1, 2, 1, 2, 0, ~0u - 4}); auto arg_vector = ConstantR1(&builder, {2, 10, 5, 1, 4, 10}); auto max_vector = ConstantR1(&builder, {3, 5, 25, 5, 123, ~0u}); Clamp(min_vector, arg_vector, max_vector); ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { XlaBuilder builder(TestName()); auto min_scalar = ConstantR0(&builder, 0); auto min_vector = ConstantR1(&builder, {1, 0, 1, 2, 0}); auto arg_vector = ConstantR1(&builder, {2, 10, 0, 1, 4}); auto max_scalar = ConstantR0(&builder, 3); auto max_vector = ConstantR1(&builder, {3, 1, 25, 5, 123}); // Perform clamp with broadcasted scalar and vector. 
Add(Add(Clamp(min_vector, arg_vector, max_scalar), Clamp(min_scalar, arg_vector, max_vector)), Add(Clamp(min_vector, arg_vector, max_vector), Clamp(min_scalar, arg_vector, max_scalar))); ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Literal param1_literal = LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(param1_literal).ConsumeValueOrDie(); auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); ComputeAndCompareR1(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, {param0_data.get(), param1_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).ConsumeValueOrDie(); Literal param1_literal = LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(param1_literal).ConsumeValueOrDie(); auto p0 = Parameter(&builder, 0, param0_literal.shape(), "param0"); auto p1 = Parameter(&builder, 1, param1_literal.shape(), "param1"); Add(p0, p1); Array3D expected(0, 7, 0); ComputeAndCompareR3( &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { XlaBuilder builder(TestName()); Literal param0_literal = LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(param0_literal).ConsumeValueOrDie(); auto a = ConstantR1(&builder, {1.1f, 2.2f, 3.3f, 4.4f}); auto p = Parameter(&builder, 0, 
param0_literal.shape(), "param0"); Add(a, p); ComputeAndCompareR1(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, {param0_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); Cos(a); ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f}); Sin(a); ComputeAndCompareR1(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); auto b = ConstantR1(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); Atan2(a, b); ComputeAndCompareR1( &builder, {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {-2.5f, 3.14f, 2.25f}); Tanh(a); ComputeAndCompareR1(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh // implementation on XLA CPU. 
XlaBuilder builder(TestName()); auto input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, 0.04, 1.02, 0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81, 0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26, -1.29, 1.35, 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(input_literal)); auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Tanh(input); ComputeAndCompareR1( &builder, {0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, -0.71985596, -0.45764771, 0.66664988, -0.58278900, 0.16050975, -0.06770509, 0.36843640, -0.38476998, 0.04018109, 0.87562293, 0.84788644, 0.38603750, 0.57294142, -0.79140943, 0.31032649, -0.89590985, -0.64770776, -0.79625875, 0.72234446, -0.77389336, -0.28871772, -0.80428445, -0.82541436, 0.90456349, -0.68856895, 0.03877772, 0.76877952, 0.32561871, -0.54546672, 0.39072621, 0.07273290, -0.01924866, 0.88924897, -0.55283129, 0.67183107, 0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, -0.41581789, 0.72739530, 0.85025692, -0.85931867, 0.87357593, 0.07782833, -0.84597743, -0.72748238, 0.45396307, 0.82449573, -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558, -0.65565848, 0.88789743, 0.83566397, 0.78287679}, {input_data.get()}, // The error spec is unusually high here to account for the fact that we // use a rational interpolant to approximate tanh. ErrorSpec(0.004, 0.004)); } XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. XlaBuilder builder(TestName()); // Just to help make sense of the scales here -- exp(89) saturates float32 and // exp(-10) is smaller than our error spec. 
Literal input_literal = LiteralUtil::CreateR1( {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, -16.3, -15.2, -14.1, -13.0, -11.9, -10.8, -9.7, -8.6, -7.5, -6.4, -5.3, -4.2, -3.1, -2.0, -0.9, 0.2, 1.3, 2.4, 3.5, 4.6, 5.7, 6.8, 7.9, 9.0, 10.1, 11.2, 12.3, 13.4, 14.5, 15.6, 16.7, 17.8, 18.9, 20.0, 21.1, 22.2, 23.3, 24.4, 25.5, 26.6, 27.7, 28.8, 29.9, 31.0, 32.1, 68.4, 69.5, 70.6, 71.7, 72.8, 73.9, 75.0, 76.1, 77.2, 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(input_literal)); auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Exp(input); std::vector expected_result; int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { expected_result.push_back(std::exp(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { // The input tensor is large enough to exercise the vectorized exp // implementation on XLA CPU. 
XlaBuilder builder(TestName()); Literal input_literal = LiteralUtil::CreateR1( {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, 1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07, 1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09, 1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12, 1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15, 1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17, 2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20, 1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22, 1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25, 1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28, 1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30, 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, client_->TransferToServer(input_literal)); auto input = Parameter(&builder, 0, input_literal.shape(), "input"); Log(input); std::vector expected_result; int64 input_size = input_literal.shape().dimensions(0); expected_result.reserve(input_size); for (int64 i = 0; i < input_size; i++) { expected_result.push_back(std::log(input_literal.Get({i}))); } ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, error_spec_); } XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { XlaBuilder builder(TestName()); auto a = ConstantR1( &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); Clz(a); ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); Clz(a); ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); } 
// Adds three constant F32 vectors as a left fold: (a + b) + c.
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
  // a ------ (add) --------- (add)
  //         /               /
  // b -----/               /
  // c---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = Add(a, b);
  Add(add, c);

  ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

// Adds three constant F32 vectors as a right fold: a + (b + c).
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
  // b ------ (add) --------- (add)
  //         /               /
  // c -----/               /
  // a---------------------/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});

  auto add = Add(b, c);
  Add(a, add);

  ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
                             error_spec_);
}

// Negates both operands before adding them: (-a) + (-b).
XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
  // a ----- (neg) ----- (add)
  //                    /
  // b ----- (neg) ----/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});

  auto neg_a = Neg(a);
  auto neg_b = Neg(b);
  Add(neg_a, neg_b);

  ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
                             error_spec_);
}

// Adds two independent sums together: (a + b) + (c + d).
XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
  // The diagram avoids ending any line with a backslash, which would splice
  // the following source line into the comment.
  //
  // a ------ (add) ----------+
  //         /                |
  // b -----/               (add)
  //                          |
  // c ------ (add) ----------+
  //         /
  // d -----/
  XlaBuilder builder(TestName());

  auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
  auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
  auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
  auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});

  auto add_ab = Add(a, b);
  auto add_cd = Add(c, d);
  Add(add_ab, add_cd);

  ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
                             error_spec_);
}

// Element-wise add of two constant 2x3 F32 matrices.
XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b);

  Array2D<float> expected_array(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
  // Add a scalar + matrix.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = ConstantR0<float>(&builder, 3.0f);
  Add(scalar, a);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
  // Add a matrix + scalar.
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto scalar = ConstantR0<float>(&builder, 3.0f);
  Add(a, scalar);

  Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
  // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches
  // only dim 0 of the matrix.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
  // clang-format off
  auto m = ConstantR2<float>(&builder, {
    {-2.5f, 3.14f, 1.0f},
    {2.25f, -10.0f, 3.33f}});
  // clang-format on
  Add(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_array(
      {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
  // Test broadcasting in Eq comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});

  // This test exercises both possible broadcast dimensions for a vector/matrix
  // comparison.
  auto cmp_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{1});
  auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
  Tuple(&builder, {cmp_dim_0, cmp_dim_1});

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}),
       LiteralUtil::CreateR2<bool>({{true, false}, {false, false}})});
  ComputeAndCompareTuple(&builder, expected, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
  // Test broadcasting in Ne comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {42, 73});
  auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
  Ne(v, m, /*broadcast_dimensions=*/{1});

  // NOTE(review): the expected text is compared verbatim against
  // ExecuteToString(). Its internal whitespace may have been collapsed together
  // with this file's line breaks -- confirm against the literal printer's
  // actual output. The same applies to the other pred[...] expectations in the
  // comparison tests below.
  const string expected = R"(pred[2,2] { { 00 }, { 01 } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
  // Test broadcasting in Ge comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Ge(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] { { 1100 }, { 0001 } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
  // Test broadcasting in Gt comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Gt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] { { 0100 }, { 0000 } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
  // Test broadcasting in Le comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Le(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] { { 1011 }, { 1111 } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
  // Test broadcasting in Lt comparison.
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
  auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
  Lt(v, m, /*broadcast_dimensions=*/{1});

  const string expected = R"(pred[2,4] { { 0011 }, { 1110 } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
  // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
  // arguments is reversed.
  XlaBuilder builder(TestName());
  auto m =
      ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
  auto v = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
  Mul(m, v, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_array(
      {{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m's shape in XLA notation is {3, 2}
  // md's shape in XLA notation is {3, 1}
  // The result has shape {3, 2}, where md is broadcast over m
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
  Add(m, md);

  Array2D<float> expected_array(
      {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
  // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
  XlaBuilder builder(TestName());
  // m's shape in XLA notation is {3, 2}
  // md's shape in XLA notation is {1, 2}
  // The result has shape {3, 2}, where md is broadcast over m
  auto m = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto md = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
  Add(m, md);

  Array2D<float> expected_array(
      {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
  // Tests broadcasting for two degenerate arrays. This kind of broadcasting
  // effectively creates an "outer product" operation.
  // This is taken from the Numpy docs example at:
  // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
  XlaBuilder builder(TestName());
  // a's shape in XLA notation is {1, 4}
  // b's shape in XLA notation is {3, 1}
  // The result has shape {3, 4}.
  auto a = ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
  auto b = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
  Add(a, b);

  Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
                                 {11.0f, 12.0f, 13.0f},
                                 {21.0f, 22.0f, 23.0f},
                                 {31.0f, 32.0f, 33.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
  // Add together a (2,2) array and a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{1});

  Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
  // Add together a (2,2) array and a (2) array, using dimension 1 for
  // broadcasting (though there are two ways to broadcast these shapes).
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
  auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
  Add(v, m, /*broadcast_dimensions=*/{0});

  Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
  // Binary add of two R3s together
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
                       {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
  Add(a, b);

  Array3D<float> expected_3d(
      {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
       {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
  // broadcasting (though there are two ways to broadcast these shapes).
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
    {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{2});

  Array3D<float> expected_3d(
      {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
       {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
  // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
  // broadcasting (though there are two ways to broadcast these shapes).
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
    {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}},
  });
  // clang-format on
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
  Add(a, v, /*broadcast_dimensions=*/{0});

  // clang-format off
  Array3D<float> expected_3d({
    {{11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}},
    {{27.0f, 28.0f}, {29.0f, 30.0f}, {31.0f, 32.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
  // Add together a (2, 3, 2) array with a (3, 2) array, using dimensions {1,2}
  // for broadcasting.
  XlaBuilder builder(TestName());
  // clang-format off
  Array3D<float> a_3d({
    {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
    {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}},
  });
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
  auto m = ConstantR2<float>(&builder, {
    {10.0f, 20.0f, 30.0f},
    {40.0f, 50.0f, 60.0f},
  });
  Add(a, m, /*broadcast_dimensions=*/{0, 1});

  Array3D<float> expected_3d({
    {{11.0f, 12.0f}, {23.0f, 24.0f}, {35.0f, 36.0f}},
    {{47.0f, 48.0f}, {59.0f, 60.0f}, {71.0f, 72.0f}},
  });
  // clang-format on
  ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
}

XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
  // Comparison between two 3D arrays of compatible shapes:
  // (2, 3, 2) and (2, 3, 1): expected to produce a (2, 3, 2) shape of PREDs.
  XlaBuilder builder(TestName());
  Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
                       {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
  auto a = ConstantR3FromArray3D<float>(&builder, a_3d);

  Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
  auto b = ConstantR3FromArray3D<float>(&builder, b_3d);

  Gt(a, b);

  // The result is checked through its string form only; an Array3D holding the
  // same 0/1 values used to be built here but was never read, so it is gone.
  const string expected = R"(pred[2,3,2] { { { 01 }, { 00 }, { 00 } }, { { 01 }, { 10 }, { 01 } } })";
  EXPECT_EQ(expected, ExecuteToString(&builder, {}));
}

// Element-wise add of two 2x3x4x5 F32 arrays filled with a running value, where
// b = 2 * a and the expected sum is 3 * a.
XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
  XlaBuilder builder(TestName());

  // Plain stack objects, as in R4_16x16x2x2_Plus_R1_16; no heap-allocated
  // wrappers are needed.
  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> operand_b_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  float value = 0.0;
  for (int64 p = 0; p < 2; ++p) {
    for (int64 z = 0; z < 3; ++z) {
      for (int64 y = 0; y < 4; ++y) {
        for (int64 x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          operand_b_4d(p, z, y, x) = 2.0 * value;
          expected_4d(p, z, y, x) = 3.0 * value;
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR4FromArray4D<float>(&builder, operand_b_4d);
  Add(a, b);

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}

// Adds a 3-element F32 vector {1, 2, 3} to a 2x3x4x5 array, broadcasting the
// vector along dimension 1 (the `z` index below).
XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
  XlaBuilder builder(TestName());

  Array4D<float> operand_a_4d(2, 3, 4, 5);
  Array4D<float> expected_4d(2, 3, 4, 5);
  std::vector<float> operand_b_1d(3);
  std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);

  float value = 0.0;
  for (int64 p = 0; p < 2; ++p) {
    for (int64 z = 0; z < 3; ++z) {
      for (int64 y = 0; y < 4; ++y) {
        for (int64 x = 0; x < 5; ++x) {
          operand_a_4d(p, z, y, x) = value;
          expected_4d(p, z, y, x) = value + operand_b_1d[z];
          value += 0.1;
        }
      }
    }
  }

  auto a = ConstantR4FromArray4D<float>(&builder, operand_a_4d);
  auto b = ConstantR1<float>(&builder, operand_b_1d);
  Add(a, b, {1});

  ComputeAndCompareR4<float>(&builder, expected_4d, {}, error_spec_);
}
// Adds a 16-element F32 vector {1..16} along dimension 1 of a 16x16x2x2 array
// of ones whose literal is created with the explicit layout {0, 1, 2, 3}.
XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
  constexpr int d0 = 16;
  constexpr int d1 = 16;
  constexpr int d2 = 2;
  constexpr int d3 = 2;
  Array4D<float> r4(d0, d1, d2, d3);
  r4.Fill(1.0);
  std::vector<float> r1(d1);
  std::iota(r1.begin(), r1.end(), 1.0);

  XlaBuilder builder(TestName());
  Literal a_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
  auto a = ConstantLiteral(&builder, a_literal);
  auto b = ConstantR1<float>(&builder, r1);
  Add(a, b, {1});

  // r4 was already copied into a_literal above, so it is reused in place to
  // hold the expected result.
  for (int i0 = 0; i0 < d0; ++i0) {
    for (int i1 = 0; i1 < d1; ++i1) {
      for (int i2 = 0; i2 < d2; ++i2) {
        for (int i3 = 0; i3 < d3; ++i3) {
          r4(i0, i1, i2, i3) += r1[i1];
        }
      }
    }
  }
  ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_);
}

// Show that we can't add two opaques.
XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
  XlaBuilder builder(TestName());
  auto shape = ShapeUtil::MakeOpaqueShape();
  auto x = Parameter(&builder, 0, shape, "x");
  Add(x, x);
  // The error is expected to surface when the computation is built.
  auto computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().ToString(),
              ::testing::ContainsRegex(
                  "Expected array argument for lhs of binary operation"));
}

// Passing broadcast_dimensions {0, 1} for two operands of the same 2x3 shape
// is accepted and gives the plain element-wise sum.
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b, /*broadcast_dimensions=*/{0, 1});

  Array2D<float> expected_array(
      {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
  ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}

// Passing the permuted broadcast_dimensions {1, 0} for two operands of the
// same rank must make Build() fail with a "must ... be the identity" error.
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
  XlaBuilder builder(TestName());
  auto a = ConstantR2<float>(&builder,
                             {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
  auto b = ConstantR2<float>(&builder,
                             {{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
  Add(a, b, /*broadcast_dimensions=*/{1, 0});

  auto computation_status = builder.Build();
  ASSERT_FALSE(computation_status.ok());
  EXPECT_THAT(computation_status.status().error_message(),
              ::testing::ContainsRegex("must.*be the identity"));
}

// Regression test for b/31927799. "slice - y" is fused and requires implicit
// broadcast.
XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
  XlaBuilder builder(TestName());
  auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
  auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
  // A failed transfer is reported as a test failure instead of aborting the
  // whole test binary.
  TF_ASSERT_OK_AND_ASSIGN(auto x_data, client_->TransferToServer(x_literal));
  TF_ASSERT_OK_AND_ASSIGN(auto y_data, client_->TransferToServer(y_literal));

  auto x = Parameter(&builder, 0, x_literal.shape(), "x");
  auto y = Parameter(&builder, 1, y_literal.shape(), "y");
  // The slice keeps the single element x[1], which is then subtracted against
  // the two-element y.
  auto slice = Slice(x, {1}, {2}, {1});
  Sub(slice, y);

  ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
                             error_spec_);
}

// Parameter values supplied to the ArrayElementwiseOpTestParamCount fixture.
INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
                        ArrayElementwiseOpTestParamCount,
                        ::testing::Values(127, 128, 129, 17 * 4096));

}  // namespace
}  // namespace xla