From 34a96722c9d3ee53ed3be9db5522307637877d29 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Tue, 21 Nov 2017 15:10:21 -0800 Subject: Add the first e2e scalar test with bfloat16. This test doesn't pass yet, but it's good to use it to drive future development work. PiperOrigin-RevId: 176568226 --- tensorflow/compiler/xla/tests/BUILD | 32 +++++++++ tensorflow/compiler/xla/tests/bfloat16_test.cc | 75 ++++++++++++++++++++++ .../compiler/xla/tests/client_library_test_base.h | 1 + tensorflow/compiler/xla/tests/literal_test_util.cc | 10 +++ 4 files changed, 118 insertions(+) create mode 100644 tensorflow/compiler/xla/tests/bfloat16_test.cc diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index c64d5aca4f..2e220e7293 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -769,6 +769,38 @@ xla_test( ], ) +xla_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cc"], + shard_count = 40, + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "slice_test", srcs = ["slice_test.cc"], diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc new file mode 100644 index 0000000000..26e2b1a95b --- /dev/null +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class Bfloat16Test : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(ScalarOperation)))) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(2.0f)); + auto y = builder.ConstantR0(static_cast(1.0f)); + builder.Add(x, y); + + ComputeAndCompareR0(&builder, static_cast(3.0f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(NegateScalarF16)))) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(static_cast(2.1f))); + + ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1dc274c591..af22c12684 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 75c9a0d3fb..9ae5c7b6f0 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -340,6 +340,9 @@ class NearComparator { multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { + case BF16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -525,6 +528,13 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, << message; } +template <> +bool NearComparator::ExpectValuesNear(bfloat16 expected, + bfloat16 actual) { + return ExpectValuesNear(static_cast(expected), + static_cast(actual)); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( -- cgit v1.2.3