aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/scalar_computations_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/scalar_computations_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc630
1 files changed, 630 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
new file mode 100644
index 0000000000..bd9cae4d1d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -0,0 +1,630 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ScalarComputationsTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+
+ protected:
+ // A template for building and running a binary comparison test.
+ template <typename NativeT>
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<bool>(&builder, expected, {});
+ }
+
+ template <typename NativeT>
+ void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<NativeT>(&builder, expected, {});
+ }
+};
+
+TEST_F(ScalarComputationsTest, NegateScalarF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(2.1f));
+
+ ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, NegateScalarS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<int32>(2));
+
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, 7, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));
+
+ ComputeAndCompareR0<uint32>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint8>(35), builder.ConstantR0<uint8>(57));
+
+ ComputeAndCompareR0<uint8>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
+ ComputationBuilder builder(client_, TestName());
+ const uint64 a = static_cast<uint64>(1) << 63;
+ const uint64 b = a + 1;
+ builder.Add(builder.ConstantR0<uint64>(a), builder.ConstantR0<uint64>(b));
+
+ ComputeAndCompareR0<uint64>(&builder, a + b, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
+ ComputationBuilder builder(client_, TestName());
+ const int64 a = static_cast<int64>(1) << 62;
+ const int64 b = a + 1;
+ builder.Add(builder.ConstantR0<int64>(a), builder.ConstantR0<int64>(b));
+
+ ComputeAndCompareR0<int64>(&builder, a + b, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<double>(0.25),
+ builder.ConstantR0<double>(3.5));
+
+ ComputeAndCompareR0<double>(&builder, 3.75, {});
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, -3, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
+ builder.ConstantR0<float>(5.5f)),
+ builder.ConstantR0<float>(0.5f));
+
+ ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
+ std::vector<int32> data = {0,
+ 1,
+ -1,
+ 1234,
+ 0x1a243514,
+ std::numeric_limits<int32>::max(),
+ std::numeric_limits<int32>::min()};
+
+ for (int32 x : data) {
+ for (int32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+
+ // Signed integer overflow is undefined behavior in C++. Convert the input
+ // integers to unsigned, perform the multiplication unsigned, and convert
+ // back.
+ int32 expected = static_cast<uint32>(x) * static_cast<uint32>(y);
+
+ ComputeAndCompareR0<int32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
+ std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
+ 0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+ for (uint32 x : data) {
+ for (uint32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+
+ uint32 expected = x * y;
+ ComputeAndCompareR0<uint32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(
+ builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
+ builder.ConstantR0<int32>(1));
+
+ ComputeAndCompareR0<int32>(&builder, 10, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> b_data =
+ client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> c_data =
+ client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+
+ ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a");
+ ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b");
+ ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c");
+ builder.Mul(builder.Mul(a, b), c);
+
+ ComputeAndCompareR0<float>(&builder, 5.775f,
+ {a_data.get(), b_data.get(), c_data.get()},
+ error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));
+
+ ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<float>(2.5f), builder.ConstantR0<float>(5.0f));
+
+ ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
+
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
+
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(7919));
+
+ ComputeAndCompareR0<int32>(&builder, -1309, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(INT_MAX));
+
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
+ builder.Rem(x, builder.ConstantR0<int32>(80000));
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal));
+ ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32
+ // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF).
+ builder.Div(builder.ConstantR0<uint32>(0xFFFFFFFE),
+ builder.ConstantR0<uint32>(2));
+
+ ComputeAndCompareR0<uint32>(&builder, 0x7FFFFFFF, {});
+}
+
+TEST_F(ScalarComputationsTest, LogicalAnd) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalAnd(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x && y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalOr) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalOr(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x || y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalNot) {
+ for (bool x : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalNot(builder.ConstantR0<bool>(x));
+
+ ComputeAndCompareR0<bool>(&builder, !x, {});
+ }
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarTrue) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(true), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarFalse) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(false), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+// This test is an explicit version of what is happening in the following
+// templatized comparison tests.
+TEST_F(ScalarComputationsTest, CompareGtScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
+
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+// S32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq);
+}
+TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
+ TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtS32) {
+ TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeS32) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtS32) {
+ TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<int32>(std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// U32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqU32False) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeU32) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
+ TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtU32) {
+ TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeU32) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtU32) {
+ TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// F32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqF32False) {
+ TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeF32) {
+ TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
+ TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
+ TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtF32) {
+ TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeF32) {
+ TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtF32) {
+ TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt);
+}
+
+// F32 comparisons with exceptional values. The test names encode the
+// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
+TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
+ // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
+ TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
+ // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
+ TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, ExpScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Exp(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, LogScalar) {
+ ComputationBuilder builder(client_, "log");
+ builder.Log(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, TanhScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.96402758, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<double>(2.0));
+
+ ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, PowScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
+
+ ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarHigh) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(2.5f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarLow) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(-5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Above) {
+ TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Below) {
+ TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Above) {
+ TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Below) {
+ TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Below) {
+ TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Below) {
+ TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
+ // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
+ ComputationBuilder b(client_, TestName());
+ b.Div(
+ b.Sub(b.Mul(b.ConstantR0<float>(1),
+ b.Mul(b.Sub(b.ConstantR0<float>(3), b.ConstantR0<float>(1)),
+ b.Add(b.ConstantR0<float>(7), b.ConstantR0<float>(0)))),
+ b.ConstantR0<float>(4)),
+ b.ConstantR0<float>(20));
+
+ ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
+ // Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
+ ComputationBuilder b(client_, TestName());
+ b.Sub(b.Mul(b.ConstantR0<int32>(1),
+ b.Mul(b.Sub(b.ConstantR0<int32>(3), b.ConstantR0<int32>(1)),
+ b.Add(b.ConstantR0<int32>(7), b.ConstantR0<int32>(0)))),
+ b.ConstantR0<int32>(4));
+
+ ComputeAndCompareR0<int32>(&b, 10, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendLlvmBackendFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}