aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/math_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/lib/math_test.cc')
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc140
1 files changed, 140 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
new file mode 100644
index 0000000000..1df287d7db
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class MathTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(MathTest, SqrtF32) {
+ XlaBuilder builder(TestName());
+ Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32);
+
+ std::unique_ptr<GlobalData> zero_data =
+ client_->TransferToServer(zero_literal).ConsumeValueOrDie();
+
+ XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
+ Sqrt(zero);
+
+ ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SquareTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Square(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, ReciprocalTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reciprocal(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtZeroes) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {0.0, -0.0});
+ Sqrt(x);
+
+ ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtSixValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
+ Sqrt(x);
+
+ std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, Lgamma) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.5, 1.5,
+ 2.5, -1.5, -3.5, -5.5});
+ Lgamma(x);
+
+ std::vector<float> expected = {
+ 0,
+ 0,
+ static_cast<float>(std::log(2)),
+ static_cast<float>(std::log(6)),
+ static_cast<float>(std::log(24)),
+ static_cast<float>(std::log(120)),
+ static_cast<float>(std::log(M_PI) / 2),
+ static_cast<float>(std::log(M_PI) / 2 - std::log(2)),
+ static_cast<float>(std::log(M_PI) / 2 - std::log(4) + std::log(3)),
+ static_cast<float>(std::log(M_PI) / 2 - std::log(3) + std::log(4)),
+ static_cast<float>(std::log(M_PI) / 2 - std::log(105) + std::log(16)),
+ static_cast<float>(std::log(M_PI) / 2 - std::log(10395) + std::log(64))};
+ error_spec_ = ErrorSpec{0.001};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, Digamma) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {1.0, 0.5, 1 / 3.0, 0.25, 1 / 6.0, 0.125,
+ 2.0, 3.0, 4.0, 6.0, 8.0, 9.0});
+ Digamma(x);
+
+ constexpr double euler_mascheroni =
+ 0.57721566490153286060651209008240243104215933593992;
+ std::vector<float> expected = {
+ static_cast<float>(-euler_mascheroni),
+ static_cast<float>(-2 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI / 2 / std::sqrt(3) - 3 * std::log(3) / 2 -
+ euler_mascheroni),
+ static_cast<float>(-M_PI / 2 - 3 * std::log(2) - euler_mascheroni),
+ static_cast<float>(-M_PI * std::sqrt(3) / 2 - 2 * std::log(2) -
+ 3 * std::log(3) / 2 - euler_mascheroni),
+ static_cast<float>(
+ -M_PI / 2 - 4 * std::log(2) -
+ (M_PI + std::log(2 + std::sqrt(2)) - std::log(2 - std::sqrt(2))) /
+ std::sqrt(2) -
+ euler_mascheroni),
+ static_cast<float>(1 - euler_mascheroni),
+ static_cast<float>(1.5 - euler_mascheroni),
+ static_cast<float>(11 / 6.0 - euler_mascheroni),
+ static_cast<float>(137 / 60.0 - euler_mascheroni),
+ static_cast<float>(363 / 140.0 - euler_mascheroni),
+ static_cast<float>(761 / 280.0 - euler_mascheroni)};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla