Diffstat (limited to 'tensorflow/compiler/jit/xla_local_launch_op.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_local_launch_op.cc | 342
1 file changed, 342 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/xla_local_launch_op.cc b/tensorflow/compiler/jit/xla_local_launch_op.cc
new file mode 100644
index 0000000000..7945e057cf
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_local_launch_op.cc
@@ -0,0 +1,342 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_local_launch_op.h"
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace gpu = perftools::gputools;
+
+namespace tensorflow {
+
+REGISTER_OP("_XlaLaunch")
+    .Input("constants: Tconstants")
+    .Attr("Tconstants: list(type) >= 0")
+    .Input("args: Targs")
+    .Attr("Targs: list(type) >= 0")
+    .Output("results: Tresults")
+    .Attr("Tresults: list(type) >= 0")
+    .Attr("function: func")
+    .Doc("XLA Launch Op. For use by the XLA JIT only.");
+
+// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
+class XlaAllocator : public xla::DeviceMemoryAllocator {
+ public:
+  XlaAllocator(const perftools::gputools::Platform* platform,
+               OpKernelContext* op_context);
+  ~XlaAllocator() override;
+  xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+      int device_ordinal, uint64 size, bool retry_on_failure = true) override;
+  Status Deallocate(int device_ordinal,
+                    perftools::gputools::DeviceMemoryBase* mem) override;
+
+  // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
+  // interpreted as having data type 'dtype' and shape 'shape'.
+  Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
+                              const TensorShape& shape, Tensor* tensor) const;
+
+ private:
+  OpKernelContext* const op_context_;
+
+  // Map from pointer address to the owning Tensor; used by
+  // MakeTensorFromBuffer. Also used to automatically release Tensors when the
+  // allocator is freed.
+  std::unordered_map<void*, Tensor> tensors_;
+};
+
+XlaAllocator::XlaAllocator(const perftools::gputools::Platform* platform,
+                           OpKernelContext* op_context)
+    : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
+
+XlaAllocator::~XlaAllocator() = default;
+
+xla::StatusOr<perftools::gputools::DeviceMemoryBase> XlaAllocator::Allocate(
+    int device_ordinal, uint64 size, bool retry_on_failure) {
+  AllocatorAttributes allocator_attrs;
+  allocator_attrs.set_on_host(false);
+
+  AllocationAttributes allocation_attrs;
+  allocation_attrs.no_retry_on_failure = !retry_on_failure;
+
+  Tensor t;
+  Status status = op_context_->allocate_temp(
+      DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
+      allocation_attrs);
+  if (!status.ok()) {
+    VLOG(2) << "Allocation failed " << size;
+    return status;
+  }
+  void* data =
+      reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
+  TF_RET_CHECK(data != nullptr);
+  tensors_[data] = t;
+  return perftools::gputools::DeviceMemoryBase(data, size);
+}
+
+Status XlaAllocator::Deallocate(int device_ordinal,
+                                perftools::gputools::DeviceMemoryBase* mem) {
+  if (mem->opaque() != nullptr) {
+    if (tensors_.erase(mem->opaque()) == 0) {
+      return tensorflow::errors::InvalidArgument("Unknown tensor address");
+    }
+  }
+  return Status::OK();
+}
+
+Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
+                                          DataType dtype,
+                                          const TensorShape& shape,
+                                          Tensor* out_tensor) const {
+  void* ptr = const_cast<void*>(buffer.opaque());
+  auto it = tensors_.find(ptr);
+  if (it == tensors_.end()) {
+    return errors::InvalidArgument("Unknown tensor address");
+  }
+  const Tensor& tensor = it->second;
+
+  int64 output_size = DataTypeSize(dtype) * shape.num_elements();
+  if (tensor.TotalBytes() == output_size) {
+    out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
+  } else {
+    Tensor slice = tensor.Slice(0, output_size);
+    out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
+  }
+  return Status::OK();
+}
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+    : OpKernel(ctx), device_type_(ctx->device_type()) {
+  const NameAttrList* func;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
+  function_ = *func;
+  DataTypeVector constant_types;
+  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
+  num_constant_args_ = constant_types.size();
+}
+
+Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) {
+  gpu::Platform::Id platform_id;
+  if (device_type_ == DeviceType(DEVICE_CPU)) {
+    platform_id = gpu::host::kHostPlatformId;
+  } else if (device_type_ == DeviceType(DEVICE_GPU)) {
+    platform_id = gpu::cuda::kCudaPlatformId;
+  } else {
+    return errors::InvalidArgument("Unknown device type for local _XlaLaunch");
+  }
+
+  auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id);
+  if (!platform.ok()) {
+    return StreamExecutorUtil::ConvertStatus(platform.status());
+  }
+  auto client =
+      xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
+  if (!client.ok()) {
+    return client.status();
+  }
+  const string* compiler_device;
+  if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device,
+                                   /*requires_jit=*/nullptr)) {
+    return errors::InvalidArgument("No JIT device registered for ",
+                                   device_type_.type());
+  }
+  XlaCompiler::Options options;
+  options.device_type = DeviceType(*compiler_device);
+  options.client = client.ValueOrDie();
+  options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId);
+  options.local_executable_has_hybrid_result = true;
+  *compiler = new XlaCompilationCache(options);
+  return Status::OK();
+}
+
+void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "XlaLocalLaunchOp::Compute "
+          << Canonicalize(function_.name(), function_.attr());
+  // We store information about the JIT-compiled XLA computation
+  // in the ResourceMgr.
+  ResourceMgr* rm = ctx->resource_manager();
+  OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
+
+  gpu::Stream* stream =
+      ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+  XlaCompilationCache* compiler;
+  OP_REQUIRES_OK(ctx,
+                 rm->LookupOrCreate<XlaCompilationCache>(
+                     rm->default_container(), "xla_compiler", &compiler,
+                     [this](XlaCompilationCache** compiler) {
+                       return BuildCompilationCache(compiler);
+                     }));
+  // Hold the reference to the JIT during evaluation. (We could probably
+  // free it sooner because the ResourceMgr will retain a reference, but
+  // this is more obviously correct.)
+  core::ScopedUnref compiler_ref(compiler);
+
+  xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client());
+
+  const XlaCompiler::CompilationResult* kernel;
+  xla::LocalExecutable* executable;
+  OP_REQUIRES_OK(ctx,
+                 compiler->Compile(function_, num_constant_args_, ctx, &kernel,
+                                   &executable));
+
+  VLOG(1) << "Executing XLA Computation...";
+
+  // Builds an XLA allocator for the device.
+  XlaAllocator xla_allocator(client->platform(), ctx);
+  XlaLocalRuntimeContext local_runtime_context;
+
+  std::unique_ptr<xla::ShapedBuffer> output;
+  bool output_is_tuple;
+  if (!kernel->computation.IsNull()) {
+    // Build xla::ShapedBuffers that point directly to the Tensor buffers.
+    std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
+    arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
+    arg_buffers.resize(kernel->xla_input_shapes.size());
+    std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
+
+    // Pass remaining parameters.
+    for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
+      int arg_num = kernel->xla_input_shapes[i].first;
+      const xla::Shape& shape = kernel->xla_input_shapes[i].second;
+      gpu::DeviceMemoryBase dmem(
+          const_cast<char*>(ctx->input(arg_num).tensor_data().data()),
+          ctx->input(arg_num).tensor_data().size());
+
+      arg_buffers[i] =
+          xla::ShapedBuffer::MakeArrayShapedBuffer(
+              shape, client->platform(), client->default_device_ordinal(), dmem)
+              .ConsumeValueOrDie();
+      arg_ptrs[i] = arg_buffers[i].get();
+    }
+
+    // Make the final parameter point at local_runtime_context.
+    if (kernel->requires_runtime_context) {
+      gpu::DeviceMemoryBase local_runtime_context_dmem(
+          &local_runtime_context, sizeof(local_runtime_context));
+      arg_buffers.push_back(
+          xla::ShapedBuffer::MakeArrayShapedBuffer(
+              xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
+              client->default_device_ordinal(), local_runtime_context_dmem)
+              .ConsumeValueOrDie());
+      arg_ptrs.push_back(arg_buffers.back().get());
+    }
+
+    // Execute the computation.
+    VLOG(2) << "Executing computation.";
+    xla::ExecutableRunOptions run_options;
+    run_options.set_stream(stream);
+    run_options.set_allocator(&xla_allocator);
+    run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+    Env* env = Env::Default();
+    auto start_time = env->NowMicros();
+    auto run_result = executable->Run(arg_ptrs, run_options);
+    OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+    if (local_runtime_context.error) {
+      ctx->CtxFailure(errors::InvalidArgument(
+          "Compiled kernel returned error: ", local_runtime_context.error_msg));
+      return;
+    }
+
+    output = std::move(run_result.ValueOrDie());
+    auto elapsed = env->NowMicros() - start_time;
+    VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+    // Computation output should always be a tuple.
+    if (VLOG_IS_ON(2)) {
+      VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
+    }
+    output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
+  }
+  CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+
+  // Copy XLA results to the OpOutputList.
+  int output_num = 0;
+  for (int i = 0; i < ctx->num_outputs(); ++i) {
+    if (kernel->outputs[i].is_constant) {
+      // Output is a constant
+      const Tensor& const_tensor = kernel->outputs[i].constant_value;
+      const size_t total_bytes = const_tensor.TotalBytes();
+      if (stream && total_bytes > 0) {
+        // Copy host -> device. (Empty tensors don't have backing buffers.)
+        VLOG(1) << "Constant output tensor on device";
+        Tensor* output_tensor;
+        TF_CHECK_OK(
+            ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+
+        const void* src_ptr = DMAHelper::base(&const_tensor);
+        void* dst_ptr = DMAHelper::base(output_tensor);
+        gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+        stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+      } else {
+        // No copy required.
+        ctx->set_output(i, const_tensor);
+      }
+    } else {
+      const TensorShape& shape = kernel->outputs[i].shape;
+      VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
+
+      gpu::DeviceMemoryBase buffer;
+      if (output_is_tuple) {
+        buffer = output->buffer({output_num});
+      } else {
+        CHECK_EQ(0, output_num);
+        buffer = output->buffer({});
+      }
+      Tensor output_tensor;
+      // Looks up the owning Tensor by buffer address.
+      OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
+                              buffer, ctx->expected_output_dtype(i), shape,
+                              &output_tensor));
+      ctx->set_output(i, output_tensor);
+      ++output_num;
+    }
+
+    if (VLOG_IS_ON(3)) {
+      VLOG(3) << ctx->mutable_output(i)->DebugString();
+    }
+  }
+
+  VLOG(1) << "Done";
+}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+  VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
+                        XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(
+    Name("_XlaLaunch").Device(DEVICE_GPU).HostMemory("constants"),
+    XlaLocalLaunchOp);
+
+}  // namespace tensorflow
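Editor's note: the XlaAllocator added by this change relies on one idea worth calling out: every allocation is recorded in a map keyed by its raw address, so a buffer handed back by the XLA executable can later be resolved to the Tensor that owns it (MakeTensorFromBuffer) and published as an op output without a copy. The following is a minimal, framework-free sketch of that ownership-tracking pattern under stated assumptions; TrackingAllocator and Blob are hypothetical names, not TensorFlow or XLA API.

// Sketch of the pointer-to-owner bookkeeping used by XlaAllocator above.
// Allocate() records which object owns each raw address; Lookup() mirrors
// MakeTensorFromBuffer (alias the existing storage instead of copying);
// Deallocate() mirrors XlaAllocator::Deallocate (drop the bookkeeping entry).
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <vector>

struct Blob {
  std::vector<uint8_t> bytes;  // stands in for a Tensor's backing storage
};

class TrackingAllocator {
 public:
  // Allocates `size` bytes and remembers which Blob owns the address.
  void* Allocate(size_t size) {
    auto blob = std::make_shared<Blob>();
    blob->bytes.resize(size);
    void* ptr = blob->bytes.data();
    blobs_[ptr] = std::move(blob);
    return ptr;
  }

  // Forgets the owning Blob for `ptr`; unknown addresses are an error.
  void Deallocate(void* ptr) {
    if (ptr != nullptr && blobs_.erase(ptr) == 0) {
      throw std::invalid_argument("Unknown buffer address");
    }
  }

  // Resolves a raw address back to its owner so the caller can alias the
  // allocation rather than copy it.
  std::shared_ptr<Blob> Lookup(void* ptr) const {
    auto it = blobs_.find(ptr);
    if (it == blobs_.end()) {
      throw std::invalid_argument("Unknown buffer address");
    }
    return it->second;
  }

 private:
  std::unordered_map<void*, std::shared_ptr<Blob>> blobs_;
};

int main() {
  TrackingAllocator allocator;
  void* ptr = allocator.Allocate(64);
  auto owner = allocator.Lookup(ptr);  // alias the buffer, no copy
  allocator.Deallocate(ptr);           // bookkeeping entry removed
  return owner->bytes.size() == 64 ? 0 : 1;
}

The real op keeps whole Tensors in tensors_ for the same reason: the compiled executable returns only raw device addresses, so the address-to-Tensor map is what lets Compute() wrap a result buffer as an output tensor without a device copy.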