diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 1662 |
1 files changed, 1662 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc new file mode 100644 index 0000000000..cf6f9a825c --- /dev/null +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -0,0 +1,1662 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> +#include <limits> +#include <memory> +#include <numeric> +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" 
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Common fixture for the elementwise-op tests in this file.
+class ArrayElementwiseOpTest : public ClientLibraryTestBase {
+ public:
+  // Error spec handed to the ComputeAndCompare* helpers by the
+  // floating-point tests.
+  ErrorSpec error_spec_{0.0001};
+};
+
+// Fixture for the TEST_P cases, which are parameterized on an int that the
+// tests use as the number of elements to generate.
+class ArrayElementwiseOpTestParamCount
+    : public ArrayElementwiseOpTest,
+      public ::testing::WithParamInterface<int> {};
+
+// Negating an empty F32 vector yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
+  ComputationBuilder builder(client_, TestName());
+  auto operand = builder.ConstantR1<float>({});
+  auto negated = builder.Neg(operand);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Elementwise negation of an F32 vector.
+TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
+  ComputationBuilder builder(client_, TestName());
+  auto operand = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto negated = builder.Neg(operand);
+
+  ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
+                             error_spec_);
+}
+
+// Elementwise negation of an S32 vector, including both extremes.
+TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
+  ComputationBuilder builder(client_, TestName());
+  auto operand = builder.ConstantR1<int32>({-1, 0, 1, 324,
+                                            std::numeric_limits<int32>::min(),
+                                            std::numeric_limits<int32>::max()});
+  auto negated = builder.Neg(operand);
+
+  // -min == min for int32 because the negation overflows. Doing this
+  // calculation in C++ is undefined behavior, but XLA has not specified it
+  // as such, so it ought to work.
+  ComputeAndCompareR1<int32>(&builder,
+                             {1, 0, -1, -324, std::numeric_limits<int32>::min(),
+                              -std::numeric_limits<int32>::max()},
+                             {});
+}
+
+// Elementwise sum of two F32 vectors.
+TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto rhs = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+  auto sum = builder.Add(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
+                             error_spec_);
+}
+
+// Adding two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({});
+  auto rhs = builder.ConstantR1<float>({});
+  auto sum = builder.Add(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Adds `count` generated values, feeding each operand into the computation
+// both as an embedded constant and as a run-time parameter, and sums all four
+// constant/parameter pairings.
+TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
+  const int count = GetParam();
+  ComputationBuilder builder(client_, TestName());
+  std::vector<float> a_values;
+  std::vector<float> b_values;
+  for (int i = 0; i < count; ++i) {
+    a_values.push_back(i / static_cast<float>(count));
+    b_values.push_back(2 * i / static_cast<float>(count + 2));
+  }
+
+  std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+  std::unique_ptr<GlobalData> a_data =
+      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+  auto a_constant = builder.ConstantR1<float>(a_values);
+  auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+
+  // The handle names now match what they hold, and parameter 1 is declared
+  // with the shape of the literal that is actually bound to it.
+  std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+  std::unique_ptr<GlobalData> b_data =
+      client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+  auto b_constant = builder.ConstantR1<float>(b_values);
+  auto b_param = builder.Parameter(1, b_literal->shape(), "b_param");
+
+  auto sum1 = builder.Add(a_constant, b_constant);
+  auto sum2 = builder.Add(a_constant, b_param);
+  auto sum3 = builder.Add(a_param, b_constant);
+  auto sum4 = builder.Add(a_param, b_param);
+
+  auto sum = builder.Add(sum1, sum2);
+  sum = builder.Add(sum, sum3);
+  sum = builder.Add(sum, sum4);
+
+  // Every pairing contributes a[i] + b[i], hence the factor of four.
+  std::vector<float> expected;
+  for (int i = 0; i < count; ++i) {
+    expected.push_back(4 * (a_values[i] + b_values[i]));
+  }
+
+  ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
+                             error_spec_);
+}
+
+// Elementwise difference of two F32 vectors.
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+  auto rhs = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+  auto difference = builder.Sub(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
+                             {}, error_spec_);
+}
+
+// Subtracting two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({});
+  auto rhs = builder.ConstantR1<float>({});
+  auto difference = builder.Sub(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Elementwise difference of two S32 vectors.
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
+  auto rhs = builder.ConstantR1<int32>({-1, 2, 1, -1});
+  auto difference = builder.Sub(lhs, rhs);
+
+  ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
+}
+
+// Subtracting two empty S32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<int32>({});
+  auto rhs = builder.ConstantR1<int32>({});
+  auto difference = builder.Sub(lhs, rhs);
+
+  ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+// Elementwise quotient of two F32 vectors.
+TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+  auto rhs = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
+  auto quotient = builder.Div(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
+                             error_spec_);
+}
+
+// Dividing two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({});
+  auto rhs = builder.ConstantR1<float>({});
+  auto quotient = builder.Div(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Elementwise remainder of two F32 vectors; the expected values carry the
+// sign of the dividend.
+XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>(
+      {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
+  auto rhs = builder.ConstantR1<float>(
+      {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
+  auto remainder = builder.Rem(lhs, rhs);
+
+  ComputeAndCompareR1<float>(
+      &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
+      error_spec_);
+}
+
+// Remainder of two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({});
+  auto rhs = builder.ConstantR1<float>({});
+  auto remainder = builder.Rem(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Same inputs as RemF32s, in double precision.
+XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<double>(
+      {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
+  auto rhs = builder.ConstantR1<double>(
+      {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
+  auto remainder = builder.Rem(lhs, rhs);
+
+  ComputeAndCompareR1<double>(
+      &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
+      error_spec_);
+}
+
+// Elementwise product of two F32 vectors.
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+  auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+  auto product = builder.Mul(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
+                             {}, error_spec_);
+}
+
+// Multiplying two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<float>({});
+  auto rhs = builder.ConstantR1<float>({});
+  auto product = builder.Mul(lhs, rhs);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// S32 multiplication over every ordered pair drawn from a small seed set.
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
+  std::vector<int32> seeds = {0,
+                              1,
+                              -1,
+                              1234,
+                              0x1a243514,
+                              std::numeric_limits<int32>::max(),
+                              std::numeric_limits<int32>::min()};
+  // Pair every seed with every seed; the reference product is formed in
+  // uint32 arithmetic.
+  std::vector<int32> lhs_values, rhs_values, expected;
+  for (int32 x : seeds) {
+    for (int32 y : seeds) {
+      lhs_values.push_back(x);
+      rhs_values.push_back(y);
+      expected.push_back(static_cast<uint32>(x) * static_cast<uint32>(y));
+    }
+  }
+
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<int32>(lhs_values);
+  auto rhs = builder.ConstantR1<int32>(rhs_values);
+  auto product = builder.Mul(lhs, rhs);
+
+  ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+// Multiplying two empty S32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<int32>({});
+  auto rhs = builder.ConstantR1<int32>({});
+  auto product = builder.Mul(lhs, rhs);
+
+  ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+// U32 multiplication over every ordered pair drawn from a small seed set.
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
+  std::vector<uint32> seeds = {0,          1,          0xDEADBEEF, 1234,
+                               0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+  // Pair every seed with every seed.
+  std::vector<uint32> lhs_values, rhs_values, expected;
+  for (uint32 x : seeds) {
+    for (uint32 y : seeds) {
+      lhs_values.push_back(x);
+      rhs_values.push_back(y);
+      expected.push_back(x * y);
+    }
+  }
+
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<uint32>(lhs_values);
+  auto rhs = builder.ConstantR1<uint32>(rhs_values);
+  auto product = builder.Mul(lhs, rhs);
+
+  ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+// Full truth table for logical AND.
+TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<bool>({false, false, true, true});
+  auto rhs = builder.ConstantR1<bool>({false, true, false, true});
+  auto conjunction = builder.LogicalAnd(lhs, rhs);
+
+  ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
+}
+
+// Logical AND of two empty vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<bool>({});
+  auto rhs = builder.ConstantR1<bool>({});
+  auto conjunction = builder.LogicalAnd(lhs, rhs);
+
+  ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+// Full truth table for logical OR.
+TEST_F(ArrayElementwiseOpTest, LogicalOr) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<bool>({false, false, true, true});
+  auto rhs = builder.ConstantR1<bool>({false, true, false, true});
+  auto disjunction = builder.LogicalOr(lhs, rhs);
+
+  ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
+}
+
+// Logical OR of two empty vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) {
+  ComputationBuilder builder(client_, TestName());
+  auto lhs = builder.ConstantR1<bool>({});
+  auto rhs = builder.ConstantR1<bool>({});
+  auto disjunction = builder.LogicalOr(lhs, rhs);
+
+  ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+// Logical NOT flips every element.
+TEST_F(ArrayElementwiseOpTest, LogicalNot) {
+  ComputationBuilder builder(client_, TestName());
+  auto operand = builder.ConstantR1<bool>({false, true, true, false});
+  auto negation = builder.LogicalNot(operand);
+
+  ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
+}
+
+// Logical NOT of an empty vector yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) {
+  ComputationBuilder builder(client_, TestName());
+  auto operand = builder.ConstantR1<bool>({});
+  auto negation = builder.LogicalNot(operand);
+
+  ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+// F32 equality; any comparison involving NaN is expected to be false.
+TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+  auto right = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
+  auto is_equal = builder.Eq(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
+}
+
+// Equality of two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({});
+  auto right = builder.ConstantR1<float>({});
+  auto is_equal = builder.Eq(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+// F32 greater-or-equal, with NaN on either side.
+TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+  auto right = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+  auto is_ge = builder.Ge(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+// F32 strictly-greater, with NaN on either side.
+TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+  auto right = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+  auto is_gt = builder.Gt(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+// F32 less-or-equal, with NaN on either side.
+TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
+  auto right = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+  auto is_le = builder.Le(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
+}
+
+// F32 strictly-less, with NaN on either side.
+TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+  auto right = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+  auto is_lt = builder.Lt(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
+}
+
+// The S32 comparison tests below all pair {min, 0, max} on the left with a
+// smaller, an equal and a larger value on the right.
+TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_equal = builder.Eq(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, false, false, false, true, false, false, false, true},
+      {});
+}
+
+// Equality of two empty S32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>({});
+  auto right = builder.ConstantR1<int32>({});
+  auto is_equal = builder.Eq(left, right);
+
+  ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_not_equal = builder.Ne(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_ge = builder.Ge(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_gt = builder.Gt(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, false, false, true, false, false, true, true, false},
+      {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_le = builder.Le(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
+  const int32 kMin = std::numeric_limits<int32>::min();
+  const int32 kMax = std::numeric_limits<int32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<int32>(
+      {kMin, kMin, kMin, 0, 0, 0, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<int32>(
+      {kMin, 0, kMax, -1, 0, 1, kMin, 0, kMax});
+  auto is_lt = builder.Lt(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, true, true, false, false, true, false, false, false},
+      {});
+}
+
+// The U32 comparison tests below all pair {0, 5, max} on the left with a
+// mix of smaller, equal and larger values on the right.
+TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_equal = builder.Eq(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, false, false, false, true, false, false, false, true},
+      {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_not_equal = builder.Ne(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_ge = builder.Ge(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_gt = builder.Gt(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, false, false, true, false, false, true, true, false},
+      {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_le = builder.Le(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
+  const uint32 kMax = std::numeric_limits<uint32>::max();
+  ComputationBuilder builder(client_, TestName());
+  auto left = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, kMax, kMax, kMax});
+  auto right = builder.ConstantR1<uint32>({0, 1, kMax, 4, 5, 6, 0, 1, kMax});
+  auto is_lt = builder.Lt(left, right);
+
+  ComputeAndCompareR1<bool>(
+      &builder, {false, true, true, false, false, true, false, false, false},
+      {});
+}
+
+// Elementwise F32 power; a NaN base or a NaN exponent is expected to give NaN.
+TEST_F(ArrayElementwiseOpTest, PowF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto base = builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f});
+  auto exponent = builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN});
+  auto power = builder.Pow(base, exponent);
+
+  ComputeAndCompareR1<float>(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {},
+                             error_spec_);
+}
+
+// Power of two empty F32 vectors yields an empty vector.
+XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto base = builder.ConstantR1<float>({});
+  auto exponent = builder.ConstantR1<float>({});
+  auto power = builder.Pow(base, exponent);
+
+  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Some Pow cases that can be implemented more efficiently.
+TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { + ComputationBuilder b(client_, TestName()); + + std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f}; + std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; + + std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values); + std::unique_ptr<GlobalData> param_data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + + auto sum = b.ConstantR0<float>(0.0f); + auto param = b.Parameter(0, param_literal->shape(), "param"); + for (float exponent : exponents) { + sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent))); + } + + std::vector<float> expected; + for (auto value : values) { + float sum = 0.0f; + for (float exponent : exponents) { + sum += std::pow(value, exponent); + } + expected.push_back(sum); + } + + ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_); +} + +TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) { + const int count = GetParam(); + ComputationBuilder builder(client_, TestName()); + std::vector<float> values; + for (int i = 0; i < count; ++i) { + values.push_back(i / static_cast<float>(count)); + } + auto x = builder.ConstantR1<float>(values); + auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); + + std::vector<float> expected; + for (float value : values) { + expected.push_back(value * value); + } + + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, SquareIn4D) { + ComputationBuilder builder(client_, TestName()); + Array4D<float> values(2, 2, 2, 2); + + std::vector<float> values_vector; + std::vector<float> expected_vector; + for (int i = 0; i < values.num_elements(); ++i) { + values_vector.push_back(static_cast<float>(i) / values.num_elements()); + expected_vector.push_back(values_vector.back() * values_vector.back()); + } + values.SetValues(values_vector); + + Array4D<float> expected(2, 2, 2, 2, expected_vector); + + auto x = 
builder.ConstantR4FromArray4D<float>(values); + auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); + + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) { + ComputationBuilder builder(client_, TestName()); + Array4D<float> values(2, 2, 0, 2); + Array4D<float> expected(2, 2, 0, 2); + + auto x = builder.ConstantR4FromArray4D<float>(values); + auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f)); + + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +// GPU backend emits nvvm intrinsic for fmin and fmax, whose semantics is NOT +// such +// * fmin(NaN, x) = x +// * fmax(NaN, x) = x +// so we only test NAN on CPU. +// +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. +TEST_F(ArrayElementwiseOpTest, MinF32s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); + auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f}); +#else + auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN}); +#endif + auto minimum = builder.Min(lhs, rhs); + + ComputeAndCompareR1<float>(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {1.0f, -5.0f, 1.0f}, +#else + {1.0f, -5.0f, 1.0f, 10.0f, 6.0f}, +#endif + {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1<float>({}); + auto rhs = builder.ConstantR1<float>({}); + auto minimum = builder.Min(lhs, rhs); + ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. 
+XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25}); + auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0}); +#else + auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN}); +#endif + auto minimum = builder.Min(lhs, rhs); + + ComputeAndCompareR1<double>(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {1.0, -5.0, 1.0}, +#else + {1.0, -5.0, 1.0, 10.0, 6.0}, +#endif + {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. +TEST_F(ArrayElementwiseOpTest, MaxF32s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f}); + auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f}); +#else + auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f}); + auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN}); +#endif + auto maximum = builder.Max(lhs, rhs); + + ComputeAndCompareR1<float>(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {2.0f, 1.0f, 2.25f}, +#else + {2.0f, 1.0f, 2.25f, 10.0f, 6.0f}, +#endif + {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1<float>({}); + auto rhs = builder.ConstantR1<float>({}); + auto minimum = builder.Max(lhs, rhs); + ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); +} + +// TODO(b/28180546): Make this compile in a way that is consistent +// among backends. See comment on MinF32s test above. 
+XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) { + ComputationBuilder builder(client_, TestName()); +#if !defined(XLA_TEST_BACKEND_CPU) + auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25}); + auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0}); +#else + auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0}); + auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN}); +#endif + auto maximum = builder.Max(lhs, rhs); + + ComputeAndCompareR1<double>(&builder, +#if !defined(XLA_TEST_BACKEND_CPU) + {2.0, 1.0, 2.25}, +#else + {2.0, 1.0, 2.25, 10.0, 6.0}, +#endif + {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, MaxS32s) { + const int32 min = std::numeric_limits<int32>::min(); + const int32 max = std::numeric_limits<int32>::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<int32>( + {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = builder.ConstantR1<int32>( + {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + builder.Max(x, y); + + std::vector<int32> expected = {min, max, 0, -1, 0, 0, 0, + 1, 1, 10, max, max, max}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MinS32s) { + const int32 min = std::numeric_limits<int32>::min(); + const int32 max = std::numeric_limits<int32>::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<int32>( + {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max}); + auto y = builder.ConstantR1<int32>( + {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min}); + builder.Min(x, y); + + std::vector<int32> expected = {min, min, min, -10, -1, -1, 0, + 0, 0, 1, 0, max, min}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MaxU32s) { + const uint32 max = std::numeric_limits<uint32>::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max}); + auto y = 
builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max}); + builder.Max(x, y); + + std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MinU32s) { + const uint32 max = std::numeric_limits<uint32>::max(); + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max}); + auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max}); + builder.Min(x, y); + + std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max}; + ComputeAndCompareR1<uint32>(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, MaxTenF32s) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<float>( + {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); + auto y = builder.ConstantR1<float>( + {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0}); + builder.Max(x, y); + + std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0, + 5.0, 6.0, 7.0, 8.0, 9.0}; + ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) { + ComputationBuilder builder(client_, TestName()); + auto u = builder.ConstantR1<float>({3.5}); + auto v = builder.ConstantR1<float>({}); + builder.Max(u, v); + + ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) { + for (int broadcast_dim : {0, 1}) { + ComputationBuilder builder(client_, TestName()); + auto u = builder.ConstantR1<float>({3.5}); + auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2)); + builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim}); + + ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_); + } +} + +TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f}); + auto 
m = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + builder.Max(v, m, /*broadcast_dimensions=*/{1}); + + Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<float>({}); + auto m = builder.ConstantR2<float>({{}, {}}); + builder.Max(v, m, /*broadcast_dimensions=*/{1}); + + Array2D<float> expected({{}, {}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) { + ComputationBuilder builder(client_, TestName()); + auto scalar = builder.ConstantR0<int32>(2); + Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}}); + auto array = builder.ConstantR3FromArray3D<int32>(a_3d); + builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + + Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}}); + ComputeAndCompareR3<int32>(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) { + ComputationBuilder builder(client_, TestName()); + auto scalar = builder.ConstantR0<int32>(2); + Array3D<int32> a_3d(2, 0, 3); + auto array = builder.ConstantR3FromArray3D<int32>(a_3d); + builder.Max(array, scalar, /*broadcast_dimensions=*/{}); + + Array3D<int32> expected(2, 0, 3); + ComputeAndCompareR3<int32>(&builder, expected, {}); +} + +TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) { + ComputationBuilder builder(client_, TestName()); + auto m = + builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}}); + auto v = builder.ConstantR1<float>({-10.2f, 16.4f}); + builder.Min(m, v, /*broadcast_dimensions=*/{0}); + + Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + 
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantR2<float>({{}, {}}); + auto v = builder.ConstantR1<float>({-10.2f, 16.4f}); + builder.Min(m, v, /*broadcast_dimensions=*/{0}); + + Array2D<float> expected({{}, {}}); + ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) { + ComputationBuilder builder(client_, TestName()); + auto array2d = + builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + auto array4d = builder.ConstantR4FromArray4D<float>( + {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}}, + {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}}); + builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + + Array4D<float> expected( + {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}}, + {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}}); + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + auto array2d = + builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}}); + Array4D<float> arg(2, 2, 0, 3); + auto array4d = builder.ConstantR4FromArray4D<float>(arg); + builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3}); + + Array4D<float> expected(2, 2, 0, 3); + ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + builder.Min(x, y); + + std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) { + ComputationBuilder builder(client_, TestName()); + auto x = 
builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); + builder.Max(x, y); + + std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9}; + ComputeAndCompareR1<int32>(&builder, expected, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1}); + auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10}); + auto add = builder.Rem(a, b); + + ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {}); +} + +TEST_F(ArrayElementwiseOpTest, NonNanClampF32) { + ComputationBuilder builder(client_, TestName()); + auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f}); + auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + auto clamp = builder.Clamp(minimum, argument, maximum); + + ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) { + ComputationBuilder builder(client_, TestName()); + auto minimum = builder.ConstantR0<float>(0.0f); + auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto maximum = builder.ConstantR0<float>(5.0f); + auto clamp = builder.Clamp(minimum, argument, maximum); + + ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0<float>(0.0f); + auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); + auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto arg_scalar = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); + auto max_scalar = builder.ConstantR0<float>(3.0f); + auto 
max_vector = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); + // Perform clamp with broadcasted scalar and vector. + auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector), + builder.Clamp(min_scalar, arg_scalar, max_vector))); + + ComputeAndCompareR1<float>(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr<Literal> param0_literal = + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f}); + std::unique_ptr<GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr<Literal> param1_literal = + LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f}); + std::unique_ptr<GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); + auto add = builder.Add(p0, p1); + + ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f}, + {param0_data.get(), param1_data.get()}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr<Literal> param0_literal = + LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0)); + std::unique_ptr<GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + std::unique_ptr<Literal> param1_literal = + LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0)); + std::unique_ptr<GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); + auto p1 = 
builder.Parameter(1, param1_literal->shape(), "param1"); + auto add = builder.Add(p0, p1); + + Array3D<float> expected(0, 7, 0); + ComputeAndCompareR3<float>( + &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr<Literal> param0_literal = + LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f}); + std::unique_ptr<GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + + auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f}); + auto p = builder.Parameter(0, param0_literal->shape(), "param0"); + auto add = builder.Add(a, p); + + ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f}, + {param0_data.get()}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, TanhF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f}); + auto result = builder.Tanh(a); + + ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { + // a ------ (add) --------- (add) + // / / + // b -----/ / + // c---------------------/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f}); + + auto add = builder.Add(a, b); + auto add2 = builder.Add(add, c); + + ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) { + // b ------ (add) --------- (add) + // / / + // c -----/ / + // a---------------------/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1<float>({2.1f, 
3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f}); + + auto add = builder.Add(b, c); + auto add2 = builder.Add(a, add); + + ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddWithNeg) { + // a ----- (neg) ----- (add) + // / + // b ----- (neg) ----/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f}); + + auto neg_a = builder.Neg(a); + auto neg_b = builder.Neg(b); + auto result = builder.Add(neg_a, neg_b); + + ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) { + // a ------ (add) ------------\ + // / \ + // b -----/ (add) + // / + // c ------ (add) ------------/ + // / + // d -----/ + ComputationBuilder builder(client_, TestName()); + + auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f}); + auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f}); + auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f}); + auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f}); + + auto add_ab = builder.Add(a, b); + auto add_cd = builder.Add(c, d); + auto add_all = builder.Add(add_ab, add_cd); + + ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {}, + error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = + builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + auto add = builder.Add(a, b); + + Array2D<float> expected_array( + {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) { + // 
Add a scalar + matrix. + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = builder.ConstantR0<float>(3.0f); + auto add = builder.Add(scalar, a); + + Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) { + // Add a matrix + scalar. + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto scalar = builder.ConstantR0<float>(3.0f); + auto add = builder.Add(a, scalar); + + Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) { + // Test simple broadcasting of a R1F32 over R2F32. The vector's size matches + // only dim 0 of the matrix. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f}); + // clang-format off + auto m = builder.ConstantR2<float>({ + {-2.5f, 3.14f, 1.0f}, + {2.25f, -10.0f, 3.33f}}); + // clang-format on + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Array2D<float> expected_array( + {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { + // Test broadcasting in Eq comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({42, 73}); + auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}}); + + // This test exercises both possible broadcast dimensions for a vector/matrix + // comparison. 
+ auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1}); + auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); + auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); + + auto expected = LiteralUtil::MakeTuple( + {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(), + LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()}); + ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { + // Test broadcasting in Ne comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({42, 73}); + auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}}); + auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,2] { + { 00 }, + { 01 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { + // Test broadcasting in Ge comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({1, 2, 3, 4}); + auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 1100 }, + { 0001 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { + // Test broadcasting in Gt comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({1, 2, 3, 4}); + auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 0100 }, + { 0000 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { + // Test broadcasting in Le comparison. 
+ ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({1, 2, 3, 4}); + auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 1011 }, + { 1111 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { + // Test broadcasting in Lt comparison. + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<int32>({1, 2, 3, 4}); + auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}}); + auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1}); + + const string expected = R"(pred[2,4] { + { 0011 }, + { 1110 }, +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) { + // Test simple broadcasting of a R1F32 over R2F32 when the order of binary op + // arguments is reversed. + ComputationBuilder builder(client_, TestName()); + auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}}); + auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f}); + auto add = builder.Mul(m, v, /*broadcast_dimensions=*/{1}); + Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) { + // Tests broadcasting for arrays with degenerate (size == 1) dimensions. 
+ ComputationBuilder builder(client_, TestName()); + // m's shape in XLA notation is {3, 2} + // md's shape in XLA notation is {3, 1} + // The result has shape {3, 2}, where md is broadcast over m + auto m = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}}); + auto add = builder.Add(m, md); + Array2D<float> expected_array( + {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) { + // Tests broadcasting for arrays with degenerate (size == 1) dimensions. + ComputationBuilder builder(client_, TestName()); + // m's shape in XLA notation is {3, 2} + // md's shape in XLA notation is {1, 2} + // The result has shape {3, 2}, where md is broadcast over m + auto m = + builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}}); + auto add = builder.Add(m, md); + Array2D<float> expected_array( + {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) { + // Tests broadcasting for two degenerate arrays. This kind of broadcasting + // effectively creates an "outer product" operation. + // This is taken from the Numpy docs example at: + // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html + ComputationBuilder builder(client_, TestName()); + // a's shape in XLA notation is {1, 4} + // b's shape in XLA notation is {3, 1} + // The result has shape {3, 4}. 
+ auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}}); + auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}}); + auto add = builder.Add(a, b); + Array2D<float> expected_array({{1.0f, 2.0f, 3.0f}, + {11.0f, 12.0f, 13.0f}, + {21.0f, 22.0f, 23.0f}, + {31.0f, 32.0f, 33.0f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) { + // Add together a (2,2) array and a (2) array, matching the vector against + // dimension 1 of the matrix via broadcast_dimensions (though there are two + // ways to broadcast these shapes). + ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<float>({20.0f, 40.0f}); + auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}}); + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1}); + Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) { + // Add together a (2,2) array and a (2) array, matching the vector against + // dimension 0 of the matrix via broadcast_dimensions (though there are two + // ways to broadcast these shapes). 
+ ComputationBuilder builder(client_, TestName()); + auto v = builder.ConstantR1<float>({20.0f, 40.0f}); + auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}}); + auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0}); + Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}}); + ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) { + // Binary add of two R3s together + ComputationBuilder builder(client_, TestName()); + Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); + auto a = builder.ConstantR3FromArray3D<float>(a_3d); + + Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}}, + {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}}); + auto b = builder.ConstantR3FromArray3D<float>(b_3d); + auto add = builder.Add(a, b); + + Array3D<float> expected_3d( + {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}}, + {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}}); + ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) { + // Add together a (2, 3, 2) array with a (2) array, matching the vector + // against dimension 2 of the array via broadcast_dimensions (though there are + // two ways to broadcast these shapes). 
+ ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D<float> a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + // clang-format on + auto a = builder.ConstantR3FromArray3D<float>(a_3d); + auto v = builder.ConstantR1<float>({10.0f, 20.0f}); + auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2}); + + Array3D<float> expected_3d( + {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}}, + {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}}); + ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) { + // Add together a (2, 3, 2) array with a (2) array, matching the vector + // against dimension 0 of the array via broadcast_dimensions (though there are + // two ways to broadcast these shapes). + ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D<float> a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + // clang-format on + auto a = builder.ConstantR3FromArray3D<float>(a_3d); + auto v = builder.ConstantR1<float>({10.0f, 20.0f}); + auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0}); + + // clang-format off + Array3D<float> expected_3d({ + {{11.0f, 12.0f}, + {13.0f, 14.0f}, + {15.0f, 16.0f}}, + {{27.0f, 28.0f}, + {29.0f, 30.0f}, + {31.0f, 32.0f}}, + }); + // clang-format on + ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) { + // Add together a (2, 3, 2) array with a (2, 3) array, matching the matrix + // against dimensions {0, 1} of the 3D array via broadcast_dimensions. 
+ ComputationBuilder builder(client_, TestName()); + // clang-format off + Array3D<float> a_3d({ + {{1.0f, 2.0f}, + {3.0f, 4.0f}, + {5.0f, 6.0f}}, + {{7.0f, 8.0f}, + {9.0f, 10.0f}, + {11.0f, 12.0f}}, + }); + auto a = builder.ConstantR3FromArray3D<float>(a_3d); + auto m = builder.ConstantR2<float>({ + {10.0f, 20.0f, 30.0f}, + {40.0f, 50.0f, 60.0f}, + }); + auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1}); + + Array3D<float> expected_3d({ + {{11.0f, 12.0f}, + {23.0f, 24.0f}, + {35.0f, 36.0f}}, + {{47.0f, 48.0f}, + {59.0f, 60.0f}, + {71.0f, 72.0f}}, + }); + // clang-format on + ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) { + // Comparison between two 3D arrays of compatible shapes. As nested in the + // initializers below, a is (2, 3, 2) and b is (1, 3, 2); b's size-1 dimension + // is broadcast, so the result is a (2, 3, 2) shape of PREDs. + ComputationBuilder builder(client_, TestName()); + Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}}, + {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}}); + auto a = builder.ConstantR3FromArray3D<float>(a_3d); + + Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}}); + auto b = builder.ConstantR3FromArray3D<float>(b_3d); + + auto compare = builder.Gt(a, b); + + // The result is checked through its string rendering: one digit per + // element, 1 where a > b. + const string expected = R"(pred[2,3,2] { +{ { 01 }, + { 00 }, + { 00 } }, +{ { 01 }, + { 10 }, + { 01 } } +})"; + EXPECT_EQ(expected, ExecuteToString(&builder, {})); +} + +TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); + std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5)); + std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5)); + float value = 0.0; + for (int64 p = 0; p < 2; ++p) { + for (int64 z = 0; z < 3; ++z) { + for (int64 y = 0; y < 4; ++y) { + for 
(int64 x = 0; x < 5; ++x) { + (*operand_a_4d)(p, z, y, x) = value; + (*operand_b_4d)(p, z, y, x) = 2.0 * value; + (*expected_4d)(p, z, y, x) = 3.0 * value; + value += 0.1; + } + } + } + } + + auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d); + auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d); + auto add = builder.Add(a, b); + + ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) { + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5)); + std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5)); + std::vector<float> operand_b_1d(3); + std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0); + + float value = 0.0; + for (int64 p = 0; p < 2; ++p) { + for (int64 z = 0; z < 3; ++z) { + for (int64 y = 0; y < 4; ++y) { + for (int64 x = 0; x < 5; ++x) { + (*operand_a_4d)(p, z, y, x) = value; + (*expected_4d)(p, z, y, x) = value + operand_b_1d[z]; + value += 0.1; + } + } + } + } + + auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d); + auto b = builder.ConstantR1<float>(operand_b_1d); + auto add = builder.Add(a, b, {1}); + + ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_); +} + +TEST_F(ArrayElementwiseOpTest, R4_32x64x2x2_Plus_R1_64) { + constexpr int d0 = 16; + constexpr int d1 = 16; + constexpr int d2 = 2; + constexpr int d3 = 2; + Array4D<float> r4(d0, d1, d2, d3); + r4.Fill(1.0); + std::vector<float> r1(d1); + std::iota(r1.begin(), r1.end(), 1.0); + + ComputationBuilder builder(client_, TestName()); + std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR4FromArray4D(r4); + *a_literal->mutable_shape()->mutable_layout() = + LayoutUtil::MakeLayout({0, 1, 2, 3}); + auto a = builder.ConstantLiteral(*a_literal); + auto b = builder.ConstantR1<float>(r1); + builder.Add(a, b, {1}); + + for (int i0 = 0; i0 < d0; ++i0) { + for (int i1 = 0; i1 < d1; ++i1) { + 
for (int i2 = 0; i2 < d2; ++i2) { + for (int i3 = 0; i3 < d3; ++i3) { + r4(i0, i1, i2, i3) += r1[i1]; + } + } + } + } + ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_); +} + +// Show that we can't add two opaques. +TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { + ComputationBuilder builder(client_, TestName()); + auto shape = ShapeUtil::MakeOpaqueShape(); + auto x = builder.Parameter(0, shape, "x"); + auto concatenated = builder.Add(x, x); + StatusOr<Computation> computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_MATCH(computation_status.status().ToString(), + testing::ContainsRegex( + "Expected non-opaque argument for lhs of binary operation")); +} + +// Regression test for b/31927799. "slice - y" is fused and requires implicit +// broadcast. +TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { + ComputationBuilder builder(client_, TestName()); + auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3}); + auto y_literal = LiteralUtil::CreateR1<float>({4, 5}); + auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); + auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); + + auto x = builder.Parameter(0, x_literal->shape(), "x"); + auto y = builder.Parameter(1, y_literal->shape(), "y"); + auto slice = builder.Slice(x, {1}, {2}); + builder.Sub(slice, y); + + ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()}, + error_spec_); +} + +INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, + ArrayElementwiseOpTestParamCount, + ::testing::Values(127, 128, 129, 17 * 4096)); + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector<tensorflow::Flag> flag_list; + xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); + xla::legacy_flags::AppendLlvmBackendFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, 
flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} |