From 552d7924ac9bb6cd2643dd49bda04a735c65f5db Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Tue, 14 Aug 2018 17:48:38 -0700 Subject: Add a buffer comparator to make it easy comparing fp16 buffers. PiperOrigin-RevId: 208747589 --- tensorflow/compiler/xla/service/gpu/BUILD | 33 ++++ .../compiler/xla/service/gpu/buffer_comparator.cc | 205 +++++++++++++++++++++ .../compiler/xla/service/gpu/buffer_comparator.h | 71 +++++++ .../xla/service/gpu/buffer_comparator_test.cc | 126 +++++++++++++ 4 files changed, 435 insertions(+) create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_comparator.cc create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_comparator.h create mode 100644 tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index bacd2c1f14..02c11335c4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -853,3 +854,35 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 0000000000..6a285a6b98 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. 
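+  //
+  // The routine below canonicalizes each f16 input (clamping to
+  // [-65505, 65505] and mapping NaNs to a large constant), then reduces the
+  // element-wise relative error |lhs - rhs| / (max(|lhs|, |rhs|) + 1) to its
+  // maximum; the caller compares that scalar against kTolerance.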
+  static constexpr char kF16CompHloText[] = R"(
+HloModule CompareF16
+
+MaxF32 {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %max = f32[] maximum(%lhs, %rhs)
+}
+
+Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] {
+  %min_constant = f32[] constant(-65505)
+  %max_constant = f32[] constant(65505)
+  %large_constant = f32[] constant(1048576)
+  %min_values = f32[SIZE] broadcast(%min_constant), dimensions={}
+  %max_values = f32[SIZE] broadcast(%max_constant), dimensions={}
+  %large_values = f32[SIZE] broadcast(%large_constant), dimensions={}
+
+  %a = f16[SIZE] parameter(0)
+  %converted = f32[SIZE] convert(%a)
+  %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values)
+
+  // Since the clamp() above already took care of infs, only NaNs will cause
+  // is-finite() to return false.
+  %is_finite = pred[SIZE] is-finite(%clamped)
+  ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values)
+}
+
+ENTRY MaxDifference {
+  %one_constant = f32[] constant(1.0)
+  %zero_constant = f32[] constant(0.0)
+
+  %ones = f32[SIZE] broadcast(%one_constant), dimensions={}
+
+  %lhs = f16[SIZE] parameter(0)
+  %rhs = f16[SIZE] parameter(1)
+  %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize
+  %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize
+  %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical)
+  %sub_abs = f32[SIZE] abs(%sub)
+  %lhs_abs = f32[SIZE] abs(%lhs_canonical)
+  %rhs_abs = f32[SIZE] abs(%rhs_canonical)
+  %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs)
+  %denominator = f32[SIZE] add(%max, %ones)
+  %error = f32[SIZE] divide(%sub_abs, %denominator)
+  ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
+})";
+  auto size_string = std::to_string(num_elements);
+  return tensorflow::str_util::StringReplace(
+      kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+}
+
+StatusOr<F16BufferComparator> F16BufferComparator::Create(
+    se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+    DeviceMemoryAllocator* allocator, se::Stream* stream) {
+  auto stream_exec = stream->parent();
+  int64 num_elements = ref_buffer.ElementCount();
+
+  // One may consider using hlo_runner to do all the compilation and execution.
+  // However, as of now hlo_runner doesn't support injection for Compiler*,
+  // Stream*, or even the allocator. We may revisit this in the future if it
+  // proves to be a maintenance burden.
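+  // The lambda below parses the HLO text above, runs the backend's HLO
+  // passes, and compiles the result into an Executable for this stream
+  // executor.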
+  TF_ASSIGN_OR_RETURN(
+      auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> {
+        HloModuleConfig config;
+        DebugOptions debug_options;
+        debug_options.set_xla_backend_optimization_level(2);
+        config.set_debug_options(debug_options);
+        TF_ASSIGN_OR_RETURN(
+            auto module, ParseHloString(GetCompHloText(num_elements), config));
+        TF_ASSIGN_OR_RETURN(
+            module,
+            compiler->RunHloPasses(std::move(module), stream_exec, nullptr));
+        return compiler->RunBackend(std::move(module), stream_exec, nullptr);
+      }()));
+
+  TF_ASSIGN_OR_RETURN(
+      auto shaped_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+        auto device_ordinal = stream_exec->device_ordinal();
+        TF_ASSIGN_OR_RETURN(
+            auto owning_buffer,
+            allocator->Allocate(device_ordinal, ref_buffer.size()));
+        se::DeviceMemory<Eigen::half> buffer(
+            owning_buffer.AsDeviceMemoryBase());
+        stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size());
+        Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+        ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal);
+        ret.set_buffer(std::move(owning_buffer), {});
+        return std::move(ret);
+      }()));
+
+  return F16BufferComparator(stream, allocator, std::move(exec),
+                             std::move(shaped_buffer));
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqualImpl(
+    se::DeviceMemory<Eigen::half> test_buffer) {
+  if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
+    return InternalError("Mismatched buffer size: %lld vs %lld",
+                         ref_buffer_.root_buffer().size(), test_buffer.size());
+  }
+
+  int64 num_elements = test_buffer.ElementCount();
+
+  TF_ASSIGN_OR_RETURN(
+      auto result_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+        auto stream_exec = stream_->parent();
+        Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+        auto device_ordinal = stream_exec->device_ordinal();
+        ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(),
+                                        device_ordinal);
+        shaped_test_buffer.set_buffer(test_buffer, {});
+        ExecutableRunOptions run_options;
+        run_options.set_device_ordinal(stream_exec->device_ordinal());
+        run_options.set_stream(stream_);
+        run_options.set_allocator(allocator_);
+        ServiceExecutableRunOptions service_run_options(run_options);
+        return exec_->ExecuteOnStream(
+            &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr);
+      }()));
+
+  float result;
+  CHECK(result_buffer.root_buffer().size() == sizeof(result));
+  stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result));
+  TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+  return result < kTolerance;
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqual(
+    se::DeviceMemory<Eigen::half> test_buffer) {
+  TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer));
+  if (result) {
+    return true;
+  }
+  // Host side code that does the same thing, but reports some of the
+  // differences as well.
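+  // Both buffers are copied to the host, canonicalized the same way the HLO
+  // routine does, and up to 10 mismatching elements are logged.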
+  int64 n = test_buffer.ElementCount();
+  std::vector<Eigen::half> host_ref_buffer(n), host_test_buffer(n);
+  stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(),
+                      ref_buffer_.root_buffer().size());
+  stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size());
+  TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+
+  const auto canonicalize = [](float a) -> float {
+    constexpr float kBigNumber = 1048576.;
+    constexpr float kMaxFp16Value = 65504.;
+    if (std::isnan(a)) {
+      return kBigNumber;
+    }
+    if (std::isinf(a)) {
+      if (a < 0) {
+        return -(kMaxFp16Value + 1);
+      }
+      return kMaxFp16Value + 1;
+    }
+    return a;
+  };
+  int differences_seen = 0;
+  for (int64 i = 0; i < n && differences_seen < 10; i++) {
+    float original_ref = static_cast<float>(host_ref_buffer[i]);
+    float original_test = static_cast<float>(host_test_buffer[i]);
+    float ref = canonicalize(original_ref);
+    float test = canonicalize(original_test);
+    if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) <
+          kTolerance)) {
+      differences_seen++;
+      LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs "
+                 << original_test;
+    }
+  }
+
+  return false;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
new file mode 100644
index 0000000000..bf2ba78cea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// An fp16 comparator that internally keeps a reference buffer, and compares it
+// against other test buffers.
+class F16BufferComparator {
+ public:
+  F16BufferComparator(const F16BufferComparator&) = delete;
+  F16BufferComparator(F16BufferComparator&&) = default;
+
+  // Creates a new comparator. It internally allocates a buffer initialized by
+  // ref_buffer.
+  static StatusOr<F16BufferComparator> Create(
+      se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+      DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+  // Returns true if the internally allocated buffer "compares equal" to
+  // test_buffer. The definition of "equal" is:
+  // * All NaNs compare equal.
+  // * All infs are treated as 65505 or -65505, so that this checker is
+  //   tolerant to fp16 overflows.
+  // * With NaNs and infs taken care of, a and b compare equal iff:
+  //     abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance
+  //
+  // See the implementation for the tolerance value.
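+  //
+  // For example, with a tolerance of 0.1, 9 and 10 compare equal because
+  // abs(9 - 10) / (max(9, 10) + 1) = 1 / 11 < 0.1, while 0 and 1 do not,
+  // since 1 / (max(0, 1) + 1) = 0.5.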
+  StatusOr<bool> CompareEqual(se::DeviceMemory<Eigen::half> test_buffer);
+
+ private:
+  F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator,
+                      std::unique_ptr<Executable> exec,
+                      ScopedShapedBuffer ref_buffer)
+      : stream_(stream),
+        allocator_(allocator),
+        exec_(std::move(exec)),
+        ref_buffer_(std::move(ref_buffer)) {}
+
+  StatusOr<bool> CompareEqualImpl(se::DeviceMemory<Eigen::half> test_buffer);
+
+  se::Stream* stream_;
+  DeviceMemoryAllocator* allocator_;
+  std::unique_ptr<Executable> exec_;
+  ScopedShapedBuffer ref_buffer_;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
new file mode 100644
index 0000000000..33761d1bd8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
+
+#include <limits>
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class BufferComparatorTest : public testing::Test {
+ protected:
+  BufferComparatorTest()
+      : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()),
+        stream_exec_(backend_->default_stream_executor()),
+        allocator_(stream_exec_->platform(), {stream_exec_}),
+        compiler_(Compiler::GetForPlatform(stream_exec_->platform())
+                      .ConsumeValueOrDie()) {}
+
+  // Takes floats only for convenience. Still uses half internally.
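+  // The helper converts both inputs to half, copies them into freshly
+  // allocated device buffers, and runs Create() on the lhs followed by
+  // CompareEqual() on the rhs.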
+  bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
+                                const std::vector<float>& rhs_float) {
+    std::vector<half> lhs(lhs_float.begin(), lhs_float.end());
+    std::vector<half> rhs(rhs_float.begin(), rhs_float.end());
+    se::Stream stream(stream_exec_);
+    stream.Init();
+
+    auto owning_lhs_buffer =
+        allocator_
+            .Allocate(stream_exec_->device_ordinal(),
+                      lhs.size() * sizeof(half))
+            .ConsumeValueOrDie();
+
+    auto owning_rhs_buffer =
+        allocator_
+            .Allocate(stream_exec_->device_ordinal(),
+                      rhs.size() * sizeof(half))
+            .ConsumeValueOrDie();
+
+    auto lhs_buffer =
+        se::DeviceMemory<half>(owning_lhs_buffer.AsDeviceMemoryBase());
+    auto rhs_buffer =
+        se::DeviceMemory<half>(owning_rhs_buffer.AsDeviceMemoryBase());
+
+    stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size());
+    stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size());
+
+    TF_CHECK_OK(stream.BlockHostUntilDone());
+
+    return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_,
+                                       &stream)
+        .ConsumeValueOrDie()
+        .CompareEqual(rhs_buffer)
+        .ConsumeValueOrDie();
+  }
+
+  std::unique_ptr<Backend> backend_;
+  se::StreamExecutor* stream_exec_;
+  StreamExecutorMemoryAllocator allocator_;
+  Compiler* compiler_;
+};
+
+TEST_F(BufferComparatorTest, TestNaNs) {
+  EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")}));
+  // NaN values with different bit patterns should compare equal.
+  EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.}));
+}
+
+TEST_F(BufferComparatorTest, TestInfs) {
+  const auto inf = std::numeric_limits<float>::infinity();
+  EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504}));
+
+  EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20}));
+}
+
+TEST_F(BufferComparatorTest, TestNumbers) {
+  EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1}));
+  EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10}));
+  EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9}));
+}
+
+TEST_F(BufferComparatorTest, TestMultiple) {
+  EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60},
+                                       {20.1, 30.1, 40.1, 50.1, 60.1}));
+  std::vector<float> lhs(200);
+  std::vector<float> rhs(200);
+  for (int i = 0; i < 200; i++) {
+    EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs))
+        << "should be the same at index " << i;
+    lhs[i] = 3;
+    rhs[i] = 5;
+    EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs))
+        << "should be different at index " << i;
+    lhs[i] = 0;
+    rhs[i] = 0;
+  }
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
-- cgit v1.2.3
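For reference, here is a minimal, self-contained host-side sketch (not part of the commit) of the comparison rule this patch implements: canonicalize NaNs and infs the way the Canonicalize HLO computation and the canonicalize lambda above do, then require the element-wise relative error abs(a - b) / (max(abs(a), abs(b)) + 1) to stay below the 0.1 tolerance. The constants mirror kTolerance and the clamp values in the patch; everything else (names, the short-circuiting loop, the example values taken from TestNumbers/TestNaNs) is illustrative only.

// Standalone sketch of the patch's comparison rule; not part of the commit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

namespace {

constexpr float kTolerance = 0.1f;        // Same value as in buffer_comparator.cc.
constexpr float kMaxFp16Value = 65504.f;  // Largest finite fp16 value.
constexpr float kBigNumber = 1048576.f;   // Stand-in for NaN, as in the HLO routine.

// Maps NaNs to a large constant and infs to just past the fp16 range, so that
// NaN "equals" NaN and inf "equals" saturated fp16 values.
float Canonicalize(float a) {
  if (std::isnan(a)) return kBigNumber;
  if (std::isinf(a)) return a < 0 ? -(kMaxFp16Value + 1) : kMaxFp16Value + 1;
  return a;
}

// Returns true iff every element pair satisfies
//   abs(a - b) / (max(abs(a), abs(b)) + 1) < kTolerance
// after canonicalization. The GPU routine computes the maximum error and
// compares it once; checking each element gives the same answer.
bool CompareEqual(const std::vector<float>& lhs, const std::vector<float>& rhs) {
  if (lhs.size() != rhs.size()) return false;
  for (size_t i = 0; i < lhs.size(); ++i) {
    float a = Canonicalize(lhs[i]);
    float b = Canonicalize(rhs[i]);
    float error = std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1);
    if (!(error < kTolerance)) return false;
  }
  return true;
}

}  // namespace

int main() {
  std::printf("9 vs 10:    %d\n", CompareEqual({9}, {10}));  // 1: 1/11 < 0.1
  std::printf("0 vs 1:     %d\n", CompareEqual({0}, {1}));   // 0: 1/2 >= 0.1
  std::printf("nan vs nan: %d\n",
              CompareEqual({std::nanf("")}, {std::nanf("")}));  // 1
  return 0;
}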