Diffstat (limited to 'tensorflow/compiler/xla/tests/hlo_test_base.cc')
-rw-r--r-- tensorflow/compiler/xla/tests/hlo_test_base.cc | 204
1 file changed, 204 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
new file mode 100644
index 0000000000..872188de81
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+// Define this in the .cc file to avoid having to include Eigen or
+// forward-declare these types in the header.
+struct HloTestBase::EigenThreadPoolWrapper {
+ std::unique_ptr<EigenThreadPoolWrapper> pool;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloTestBase::HloTestBase()
+ : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
+ test_hlo_dumper_ = [](const HloModule& module, const string& label) {
+ legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
+ if (flags->xla_hlo_test_generate_hlo_graph) {
+ const bool show_addresses = true;
+ const bool show_layouts = true;
+ hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
+ show_addresses, show_layouts);
+ }
+ };
+ VLOG(1) << "executing on platform " << backend_->platform()->Name();
+}
+
+HloTestBase::~HloTestBase() {
+ // Deallocate all the memory allocated during the tests.
+ for (auto& allocation : allocations_) {
+ backend_->default_stream_executor()->Deallocate(&allocation);
+ }
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape) {
+ auto module_config = MakeUnique<HloModuleConfig>(
+ MakeProgramShape(module->entry_computation()));
+ return Execute(std::move(module), std::move(module_config), arguments,
+ result_shape);
+}
+
+StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ Shape* result_shape) {
+ VLOG(3) << "module_config layout "
+ << LayoutUtil::HumanString(module_config->entry_computation_layout()
+ .result_layout()
+ .layout());
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ backend_->compiler()->Compile(std::move(hlo_module),
+ std::move(module_config), test_hlo_dumper_,
+ backend_->default_stream_executor()));
+
+ se::Stream stream(backend_->default_stream_executor());
+ stream.Init();
+
+ ExecutableRunOptions run_options;
+ run_options.set_stream(&stream);
+ run_options.set_allocator(backend_->memory_allocator());
+ run_options.set_inter_op_thread_pool(backend_->inter_op_thread_pool());
+ run_options.set_intra_op_thread_pool(
+ backend_->eigen_intra_op_thread_pool_device());
+
+ HloExecutionProfile hlo_execution_profile;
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result,
+ executable->ExecuteOnStream(&run_options, arguments,
+ &hlo_execution_profile));
+ TF_RET_CHECK(stream.BlockHostUntilDone());
+
+ allocations_.push_back(result);
+
+ *result_shape = executable->result_shape();
+
+ if (ShapeUtil::IsTuple(*result_shape)) {
+ // We must record element buffers of tuples as well to avoid leaks.
+ DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ backend_->transfer_manager()->ShallowCopyTupleFromDevice(
+ backend_->default_stream_executor(), result, *result_shape));
+
+ // A tuple may contain the same buffer in more than one element. Keep track
+ // of the buffers already added to avoid duplicates in allocations_.
+ std::set<void*> added_opaques;
+ for (auto element_buffer : element_buffers) {
+ if (added_opaques.count(element_buffer.opaque()) == 0) {
+ added_opaques.insert(element_buffer.opaque());
+ allocations_.push_back(element_buffer);
+ }
+ }
+ }
+
+ return result;
+}
+
+se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
+ // Allocate memory on the device using the stream executor.
+ int64 allocation_size =
+ backend_->transfer_manager()->GetByteSizeRequirement(literal.shape());
+ se::DeviceMemoryBase allocation =
+ backend_->default_stream_executor()->AllocateArray<uint8>(
+ allocation_size);
+ allocations_.push_back(allocation);
+
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice(
+ backend_->default_stream_executor(), literal, &allocation));
+
+ return allocation;
+}
+
+std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
+ const Shape& shape, se::DeviceMemoryBase device_base) {
+ auto literal = MakeUnique<Literal>();
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralFromDevice(
+ backend_->default_stream_executor(), device_base, shape, shape,
+ literal.get()));
+ return literal;
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), std::move(module_config), arguments,
+ &result_shape)
+ .ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+ProgramShape HloTestBase::MakeProgramShape(HloComputation* computation) {
+ ProgramShape program_shape;
+ for (int64 i = 0; i < computation->num_parameters(); ++i) {
+ *program_shape.add_parameters() =
+ computation->parameter_instruction(i)->shape();
+ }
+ *program_shape.mutable_result() = computation->root_instruction()->shape();
+ return program_shape;
+}
+
+string HloTestBase::TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+}
+
+} // namespace xla
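
For context on how this harness is meant to be used, the sketch below shows a hypothetical derived test that builds a small HLO graph, runs it through ExecuteAndTransfer, and checks the result on the host. It assumes the contemporaneous XLA test utilities (HloComputation::Builder, LiteralUtil::CreateR0, LiteralTestUtil::ExpectR0Equal); the test and class names are illustrative and not part of this change.

// Hypothetical usage sketch; not part of this change. Assumes 2017-era
// XLA helpers (HloComputation::Builder, LiteralUtil, LiteralTestUtil).
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace {

class HloTestBaseUsageTest : public HloTestBase {};

TEST_F(HloTestBaseUsageTest, AddsTwoConstants) {
  // Build an entry computation that adds two constant scalars.
  HloComputation::Builder builder(TestName());
  auto lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
  auto rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0f)));
  builder.AddInstruction(
      HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));

  auto module = MakeUnique<HloModule>(TestName());
  module->AddEntryComputation(builder.Build());

  // No device arguments are needed; the computation only uses constants.
  std::unique_ptr<Literal> result =
      ExecuteAndTransfer(std::move(module), /*arguments=*/{});
  LiteralTestUtil::ExpectR0Equal<float>(5.0f, *result);
}

}  // namespace
}  // namespace xla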