aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2017-11-21 15:10:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-21 15:14:24 -0800
commit34a96722c9d3ee53ed3be9db5522307637877d29 (patch)
tree2dc5fd328bce8c910261fcb28bd087f8189b6e8a
parentdb8447528c1f7d6055d9a0145aa35bbea7bfd810 (diff)
Add the first e2e scalar test with bfloat16.
This test doesn't pass yet, but it's good to use it to drive future development work. PiperOrigin-RevId: 176568226
-rw-r--r--tensorflow/compiler/xla/tests/BUILD32
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc75
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h1
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc10
4 files changed, 118 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index c64d5aca4f..2e220e7293 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -770,6 +770,38 @@ xla_test(
)
xla_test(
+ name = "bfloat16_test",
+ srcs = ["bfloat16_test.cc"],
+ shard_count = 40,
+ deps = [
+ ":test_utils",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "slice_test",
srcs = ["slice_test.cc"],
shard_count = 40,
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
new file mode 100644
index 0000000000..26e2b1a95b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class Bfloat16Test : public ClientLibraryTestBase {
+ protected:
+ const ErrorSpec error_spec_{0.001, 0.001};
+};
+
+XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
+ DISABLED_ON_CPU(ScalarOperation)))) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
+ auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
+ builder.Add(x, y);
+
+ ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
+ error_spec_);
+}
+
+XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
+ DISABLED_ON_CPU(NegateScalarF16)))) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
+
+ ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 1dc274c591..af22c12684 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
+ std::is_same<NativeT, bfloat16>::value ||
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 75c9a0d3fb..9ae5c7b6f0 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -340,6 +340,9 @@ class NearComparator {
multi_index_.resize(expected.shape().dimensions_size(), 0);
switch (expected.shape().element_type()) {
+ case BF16:
+ ExpectLiteralsNear<bfloat16>(expected, actual, 0);
+ break;
case F32:
ExpectLiteralsNear<float>(expected, actual, 0);
break;
@@ -525,6 +528,13 @@ void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual,
<< message;
}
+template <>
+bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected,
+ bfloat16 actual) {
+ return ExpectValuesNear(static_cast<float>(expected),
+ static_cast<float>(actual));
+}
+
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(