diff options
author | 2017-11-21 15:10:21 -0800 | |
---|---|---|
committer | 2017-11-21 15:14:24 -0800 | |
commit | 34a96722c9d3ee53ed3be9db5522307637877d29 (patch) | |
tree | 2dc5fd328bce8c910261fcb28bd087f8189b6e8a | |
parent | db8447528c1f7d6055d9a0145aa35bbea7bfd810 (diff) |
Add the first e2e scalar test with bfloat16.
This test doesn't pass yet, but it's good to use it to drive future development work.
PiperOrigin-RevId: 176568226
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/bfloat16_test.cc | 75 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/client_library_test_base.h | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.cc | 10 |
4 files changed, 118 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index c64d5aca4f..2e220e7293 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -770,6 +770,38 @@ xla_test( ) xla_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cc"], + shard_count = 40, + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc new file mode 100644 index 0000000000..26e2b1a95b --- /dev/null +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <cmath> +#include <memory> +#include <vector> + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class Bfloat16Test : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(ScalarOperation)))) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f)); + auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f)); + builder.Add(x, y); + + ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(NegateScalarF16)))) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f))); + + ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {}, + error_spec_); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1dc274c591..af22c12684 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -333,6 +333,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) { static_assert(std::is_same<NativeT, float>::value || std::is_same<NativeT, double>::value || + std::is_same<NativeT, bfloat16>::value || std::is_same<NativeT, complex64>::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr<Literal> expected_literal = diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 75c9a0d3fb..9ae5c7b6f0 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -340,6 +340,9 @@ class NearComparator { multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { + case BF16: + ExpectLiteralsNear<bfloat16>(expected, actual, 0); + break; case F32: ExpectLiteralsNear<float>(expected, actual, 0); break; @@ -525,6 +528,13 @@ void NearComparator::ExpectNear<complex64>(complex64 expected, complex64 actual, << message; } +template <> +bool NearComparator::ExpectValuesNear<bfloat16>(bfloat16 expected, + bfloat16 actual) { + return ExpectValuesNear(static_cast<float>(expected), + static_cast<float>(actual)); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( |