aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/tests/BUILD32
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc75
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h1
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc10
4 files changed, 118 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c64d5aca4f..2e220e7293 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -770,6 +770,38 @@ xla_test(
)
xla_test(
+ name = "bfloat16_test",
+ srcs = ["bfloat16_test.cc"],
+ shard_count = 40,
+ deps = [
+ ":test_utils",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
shard_count = 40,
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
new file mode 100644
index 0000000000..26e2b1a95b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class Bfloat16Test : public ClientLibraryTestBase {
+ protected:
+ const ErrorSpec error_spec_{0.001, 0.001};
+};
+
+XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
+ DISABLED_ON_CPU(ScalarOperation)))) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
+ auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
+ builder.Add(x, y);
+
+ ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
+ error_spec_);
+}
+
+XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
+ DISABLED_ON_CPU(NegateScalarF16)))) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
+
+ ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 1dc274c591..af22c12684 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 75c9a0d3fb..9ae5c7b6f0 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -340,6 +340,9 @@ class NearComparator {
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
+ case BF16:
+ ExpectLiteralsNear<bfloat16>(expected, actual, 0);
+ break;
case F32:
ExpectLiteralsNear<float>(expected, actual, 0);
break;
@@ -525,6 +528,13 @@ void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
<< message;
}
+template <>
+bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
+ bfloat16 actual) {
+ return ExpectValuesNear(static_cast<float>(expected),
+ static_cast<float>(actual));
+}
+
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(