author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-08 08:26:52 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-03-08 08:31:49 -0800
commit     a47cd30d960b128e5ed405cb36e914aa36fe462a (patch)
tree       fa019ded87328d5d3bf4a0ccc3b8721571006437
parent     16a6666c1c1a3f4b288472c4f461b6418bda0170 (diff)
This creates a new helper, xla_launch_util, that contains the business logic
of launching an XLA computation. It also changes the resource variable
container from a std::vector<OptionalTensor> to a std::map<int, OptionalTensor>
in preparation for backends where the resource variables aren't ordered densely
at the end of the argument list.

PiperOrigin-RevId: 188335574
-rw-r--r--  tensorflow/compiler/jit/BUILD                     |  23
-rw-r--r--  tensorflow/compiler/jit/kernels/BUILD             |   1
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc  | 250
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.h   |   8
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc  |  23
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.h   |   4
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc        | 255
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h         | 116
8 files changed, 418 insertions(+), 262 deletions(-)
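
The central data-structure change runs through the whole diff below: resource-variable
snapshots move from a dense std::vector<OptionalTensor>, indexed relative to the first
variable argument, to a std::map<int, OptionalTensor> keyed by the TensorFlow argument
index. A minimal standalone sketch of the difference, using a hypothetical simplified
OptionalTensorSketch in place of the real TensorFlow struct, is:

// --- illustrative sketch, not part of the commit ---------------------------
#include <map>
#include <string>

// Hypothetical simplified stand-in for tensorflow::OptionalTensor; the real
// struct also carries the snapshotted Tensor value.
struct OptionalTensorSketch {
  std::string name;
  bool present = false;
};

// Before this change: a dense vector, implicitly tied to the convention that
// resource variables are the *last* num_variables arguments of the op:
//   std::vector<OptionalTensorSketch> vars(num_variables);
//   vars[arg_num - first_variable_arg] ...
//
// After this change: keyed by the TensorFlow argument index itself, so the
// variables need not sit densely at the end of the argument list.
std::map<int, OptionalTensorSketch> SnapshotSketch(int num_inputs,
                                                   int num_variables) {
  std::map<int, OptionalTensorSketch> snapshot;
  const int first_variable = num_inputs - num_variables;
  for (int i = 0; i < num_variables; ++i) {
    OptionalTensorSketch& t = snapshot[first_variable + i];
    t.name = "var_" + std::to_string(first_variable + i);
    t.present = true;
  }
  return snapshot;
}

// Lookup is by argument index, with no tail-position assumption:
//   if (snapshot.count(arg_num) > 0) { use(snapshot.at(arg_num)); }
// ---------------------------------------------------------------------------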
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 955d12dc20..c4a2d4ab03 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -135,6 +135,7 @@ cc_library(
deps = [
":common",
":jit_compilation_passes",
+ ":xla_launch_util",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
@@ -175,6 +176,28 @@ cc_library(
)
cc_library(
+ name = "xla_launch_util",
+ srcs = ["xla_launch_util.cc"],
+ hdrs = ["xla_launch_util.h"],
+ deps = [
+ ":common",
+ ":xla_compilation_cache",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:variable_ops",
+ ],
+)
+
+cc_library(
name = "xla_compilation_cache",
srcs = ["xla_compilation_cache.cc"],
hdrs = ["xla_compilation_cache.h"],
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 9bea566331..616a7f8f15 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -14,6 +14,7 @@ cc_library(
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
"//tensorflow/compiler/jit:xla_device",
+ "//tensorflow/compiler/jit:xla_launch_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 6353149e4a..cd7f8dd779 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -40,111 +41,6 @@ namespace gpu = perftools::gputools;
namespace tensorflow {
-// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
-// Assumes that the Tensorflow allocator permits asynchronous deallocation:
-// see comment on `AllowsAsynchronousDeallocation()`.
-class XlaAllocator : public xla::DeviceMemoryAllocator {
- public:
- XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
- ~XlaAllocator() override;
- xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
- bool retry_on_failure) override;
- Status Deallocate(int device_ordinal, gpu::DeviceMemoryBase* mem) override;
-
- // Register an Tensor (input or resource variable) with the allocator. If
- // the operation returns an alias to one of its inputs, then the allocator
- // needs to be able to handle it.
- Status RegisterArgument(const Tensor* t);
-
- // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
- // interpreted as having data type 'dtype' and shape 'shape'.
- Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
- const TensorShape& shape, Tensor* tensor) const;
-
- // The Tensorflow BFC allocator used on GPU allows host-side deallocation
- // before GPU execution takes place. Tensorflow uses the ordering of the main
- // compute stream to enforce a happens-before relationship between a memory
- // allocation and code that reuses the same memory. If Tensorflow adds
- // support for multiple GPU streams or allocators with different ordering
- // requirements, this code may need to change.
- // (This attribute has no effect on CPU.)
- bool AllowsAsynchronousDeallocation() const override { return true; }
-
- private:
- OpKernelContext* const op_context_;
-
- // Map from pointer address to the owning Tensor; used by
- // MakeTensorFromBuffer. Also used to automatically release Tensors when the
- // allocator is freed.
- std::unordered_map<void*, Tensor> tensors_;
-};
-
-XlaAllocator::XlaAllocator(const gpu::Platform* platform,
- OpKernelContext* op_context)
- : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
-
-XlaAllocator::~XlaAllocator() = default;
-
-xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
- int device_ordinal, uint64 size, bool retry_on_failure) {
- AllocatorAttributes allocator_attrs;
- allocator_attrs.set_on_host(false);
-
- AllocationAttributes allocation_attrs;
- allocation_attrs.no_retry_on_failure = !retry_on_failure;
-
- Tensor t;
- Status status = op_context_->allocate_temp(
- DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
- allocation_attrs);
- if (!status.ok()) {
- VLOG(2) << "Allocation failed " << size;
- return status;
- }
- void* data =
- reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
- tensors_[data] = t;
- return gpu::DeviceMemoryBase(data, size);
-}
-
-Status XlaAllocator::RegisterArgument(const Tensor* t) {
- void* data =
- reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
- tensors_[data] = *t;
- return Status::OK();
-}
-
-Status XlaAllocator::Deallocate(int device_ordinal,
- gpu::DeviceMemoryBase* mem) {
- if (mem->opaque() != nullptr) {
- if (tensors_.erase(mem->opaque()) == 0) {
- return tensorflow::errors::InvalidArgument("Unknown tensor address");
- }
- }
- return Status::OK();
-}
-
-Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
- DataType dtype,
- const TensorShape& shape,
- Tensor* out_tensor) const {
- void* ptr = const_cast<void*>(buffer.opaque());
- auto it = tensors_.find(ptr);
- if (it == tensors_.end()) {
- return errors::InvalidArgument("Unknown tensor address");
- }
- const Tensor& tensor = it->second;
-
- int64 output_size = DataTypeSize(dtype) * shape.num_elements();
- if (tensor.TotalBytes() == output_size) {
- out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
- } else {
- Tensor slice = tensor.Slice(0, output_size);
- out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
- }
- return Status::OK();
-}
-
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
: OpKernel(ctx), device_type_(ctx->device_type()) {
const NameAttrList* func;
@@ -196,23 +92,6 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
-std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables) {
- std::vector<OptionalTensor> snapshot(num_variables);
- int first_variable = ctx->num_inputs() - num_variables;
- for (int i = 0; i < num_variables; ++i) {
- Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
- if (LookupResource(ctx, handle, &variable).ok()) {
- tf_shared_lock lock(*variable->mu());
- snapshot[i].name = handle.name();
- snapshot[i].present = true;
- snapshot[i].value = *variable->tensor();
- }
- }
- return snapshot;
-}
-
void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "XlaLocalLaunchOp::Compute "
<< Canonicalize(function_.name(), AttrSlice(&function_.attr()));
@@ -244,7 +123,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
}
}
- std::vector<OptionalTensor> variables =
+ std::map<int, OptionalTensor> variables =
SnapshotResourceVariables(ctx, num_resource_args_);
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
@@ -269,43 +148,9 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
- std::unique_ptr<xla::ShapedBuffer> output;
- // Build xla::ShapedBuffers that point directly to the Tensor buffers.
- std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
- arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
- arg_buffers.resize(kernel->xla_input_shapes.size());
- std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
-
- const int first_variable_arg = ctx->num_inputs() - num_resource_args_;
- // Pass remaining parameters.
- const Tensor* t;
- for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
- int arg_num = kernel->input_mapping[i];
- const xla::Shape& shape = kernel->xla_input_shapes[i];
- if (arg_num >= first_variable_arg) {
- t = &(variables[arg_num - first_variable_arg].value);
- } else {
- t = &(ctx->input(arg_num));
- }
-
- gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
- const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
-
- const xla::Shape on_device_shape =
- client->backend().transfer_manager()->HostShapeToDeviceShape(shape);
- CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
- << "On-device shape "
- << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
- << " not the same as on-host shape "
- << xla::ShapeUtil::HumanStringWithLayout(shape);
- arg_buffers[i] = xla::MakeUnique<xla::ShapedBuffer>(
- /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(),
- client->default_device_ordinal());
- arg_buffers[i]->set_buffer(dmem, /*index=*/{});
- arg_ptrs[i] = arg_buffers[i].get();
-
- OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t));
- }
+ XlaComputationLaunchContext launch_context(num_resource_args_, client,
+ &xla_allocator);
+ launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
VLOG(2) << "Executing computation.";
@@ -315,93 +160,14 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
Env* env = Env::Default();
auto start_time = env->NowMicros();
- auto run_result = executable->Run(arg_ptrs, run_options);
+ auto run_result = executable->Run(launch_context.arguments(), run_options);
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
- output = run_result.ConsumeValueOrDie()->release();
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- // Computation output should always be a tuple.
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
- }
- CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
-
- // Copy XLA results to the OpOutputList.
- int output_num = 0;
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- if (kernel->outputs[i].is_constant) {
- // Output is a constant.
- const Tensor& const_tensor = kernel->outputs[i].constant_value;
- const size_t total_bytes = const_tensor.TotalBytes();
- if (stream && total_bytes > 0) {
- // Copy host -> device. (Empty tensors don't have backing buffers.)
- VLOG(1) << "Constant output tensor on device";
- Tensor* output_tensor;
- TF_CHECK_OK(
- ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
-
- const void* src_ptr = DMAHelper::base(&const_tensor);
- void* dst_ptr = DMAHelper::base(output_tensor);
- gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
- stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
- } else {
- // No copy required.
- ctx->set_output(i, const_tensor);
- }
- } else {
- const TensorShape& shape = kernel->outputs[i].shape;
- VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
-
- gpu::DeviceMemoryBase buffer = output->buffer({output_num});
- Tensor output_tensor;
- // Looks up the owning Tensor by buffer address.
- OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
- buffer, ctx->expected_output_dtype(i), shape,
- &output_tensor));
- ctx->set_output(i, output_tensor);
- ++output_num;
- }
-
- if (VLOG_IS_ON(3)) {
- VLOG(3) << ctx->mutable_output(i)->DebugString();
- }
- }
-
- // Apply variable updates, if any.
- VLOG(2) << "Applying variable updates";
- for (int i = 0; i < kernel->resource_updates.size(); ++i) {
- const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- OP_REQUIRES(ctx,
- write.input_index >= 0 && write.input_index < ctx->num_inputs(),
- errors::Internal("Invalid input index for variable write."));
-
- gpu::DeviceMemoryBase buffer = output->buffer({output_num});
-
- Var* variable = nullptr;
- // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not
- // a Tensor.
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index),
- &variable, [this, ctx, &write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
-
- core::ScopedUnref s(variable);
-
- mutex_lock ml(*variable->mu());
- OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
- errors::Internal("Mismatched type in variable write"));
-
- // Looks up the owning Tensor by buffer address.
- OP_REQUIRES_OK(
- ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
- variable->tensor()));
- ++output_num;
- }
-
+ launch_context.PopulateOutputs(ctx, kernel,
+ run_result.ConsumeValueOrDie()->release());
VLOG(1) << "Done";
}
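
The net effect on XlaLocalLaunchOp::Compute is that input marshalling and output
population are delegated to the new XlaComputationLaunchContext, leaving the op with a
populate-inputs / run / populate-outputs sequence. The sketch below mirrors that
ownership pattern with hypothetical stand-in types (BufferSketch, KernelSketch,
LaunchContextSketch are not TensorFlow types): the context owns the argument buffers
and exposes a non-owning pointer list to hand to the executable.

// --- illustrative sketch, not part of the commit ---------------------------
#include <map>
#include <memory>
#include <vector>

// Hypothetical stand-ins for xla::ShapedBuffer and the compiled kernel.
struct BufferSketch { int tf_arg_index = 0; };
struct KernelSketch { std::vector<int> input_mapping; };

class LaunchContextSketch {
 public:
  // Mirrors PopulateInputs(): one buffer per compiled input; resource
  // variables are found in the map by their TensorFlow argument index.
  void PopulateInputs(const KernelSketch& kernel,
                      const std::map<int, bool>& variables) {
    arg_buffers_.clear();
    arg_ptrs_.clear();
    for (int arg_num : kernel.input_mapping) {
      // The real code reads either the variable snapshot or ctx->input(),
      // depending on whether arg_num is a key in `variables`.
      (void)variables.count(arg_num);
      auto buffer = std::make_unique<BufferSketch>();
      buffer->tf_arg_index = arg_num;
      arg_ptrs_.push_back(buffer.get());
      arg_buffers_.push_back(std::move(buffer));
    }
  }

  // Mirrors arguments(): only meaningful after PopulateInputs().
  const std::vector<BufferSketch*>& arguments() const { return arg_ptrs_; }

  // Mirrors PopulateOutputs(): consumes the result of the execution.
  void PopulateOutputs(std::unique_ptr<BufferSketch> /*output*/) {}

 private:
  std::vector<std::unique_ptr<BufferSketch>> arg_buffers_;  // owning storage
  std::vector<BufferSketch*> arg_ptrs_;                     // non-owning views
};
// Call order, as in XlaLocalLaunchOp::Compute above: PopulateInputs(), run the
// executable with arguments(), then PopulateOutputs() with the run result.
// ---------------------------------------------------------------------------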
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index 47fd912b12..c6cc0986af 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -26,14 +26,6 @@ limitations under the License.
namespace tensorflow {
-// Takes a snapshot of the values of resource variable arguments, which are
-// the last `num_variables` arguments. We snapshot tensors that back
-// resource variables since concurrent updates may modify the shape, and it is
-// important that the shapes used for compilation match the true shapes of the
-// buffers.
-std::vector<OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
- int num_variables);
-
// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
// responsible for handling interactions with the TensorFlow executor.
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 6d854a920e..8cc79a9bd0 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -93,7 +93,7 @@ uint64 XlaCompilationCache::Signature::Hash::operator()(
Status XlaCompilationCache::BuildSignature(
const NameAttrList& function, int num_constant_args,
- const std::vector<OptionalTensor>& variable_args, OpKernelContext* ctx,
+ const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx,
Signature* signature) {
signature->name = Canonicalize(function.name(), AttrSlice(&function.attr()));
signature->arg_values.resize(num_constant_args);
@@ -115,7 +115,8 @@ Status XlaCompilationCache::BuildSignature(
}
// For variable signatures, use the type and shape of the variable's
// current value.
- for (const OptionalTensor& variable : variable_args) {
+ for (auto& iterator : variable_args) {
+ const OptionalTensor& variable = iterator.second;
TF_RET_CHECK(input_num < ctx->num_inputs());
if (variable.present) {
signature->arg_types.emplace_back(variable.value.dtype(),
@@ -133,7 +134,7 @@ namespace {
// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
// op. The first `num_constant_args` arguments must be host-memory Tensors.
Status BuildArguments(int num_constant_args,
- const std::vector<OptionalTensor>& variable_args,
+ const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
std::vector<XlaCompiler::Argument>* args) {
args->resize(ctx->num_inputs());
@@ -175,17 +176,17 @@ Status BuildArguments(int num_constant_args,
// Handles resource variables.
TF_RET_CHECK(input_num + num_variable_args == ctx->num_inputs());
- for (int variable_id = 0; variable_id < num_variable_args; ++variable_id) {
+ for (auto& iterator : variable_args) {
const Tensor& input = ctx->input(input_num);
TF_RET_CHECK(input.dtype() == DT_RESOURCE);
XlaCompiler::Argument& arg = (*args)[input_num];
- arg.name = variable_args[variable_id].name;
+ arg.name = iterator.second.name;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
- if (variable_args[variable_id].present) {
- const Tensor& value = variable_args[variable_id].value;
+ if (iterator.second.present) {
+ const Tensor& value = iterator.second.value;
arg.type = value.dtype();
arg.shape = value.shape();
arg.initialized = true;
@@ -233,7 +234,7 @@ Status XlaCompilationCache::BuildExecutable(
Status XlaCompilationCache::Compile(
const XlaCompiler::Options& options, const NameAttrList& function,
- int num_constant_args, const std::vector<OptionalTensor>& variable_args,
+ int num_constant_args, const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
@@ -250,10 +251,12 @@ Status XlaCompilationCache::Compile(
<< " present=" << ctx->has_input(i)
<< " shape=" << shape.DebugString();
}
- for (const OptionalTensor& variable : variable_args) {
+ for (auto& iterator : variable_args) {
+ const OptionalTensor& variable = iterator.second;
VLOG(2) << "variable present=" << variable.present
<< " type=" << DataTypeString(variable.value.dtype())
- << " shape=" << variable.value.shape().DebugString();
+ << " shape=" << variable.value.shape().DebugString()
+ << " TF arg= " << iterator.first;
}
VLOG(2) << "num_outputs = " << ctx->num_outputs();
for (int i = 0; i < ctx->num_outputs(); i++) {
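
BuildSignature and BuildArguments now iterate the variable map instead of indexing a
dense vector. Because std::map visits its keys in ascending order, the per-variable
ordering in the signature and argument lists stays deterministic and follows the
TensorFlow argument index, even when those indices are not contiguous. A small
standalone illustration (not TensorFlow code):

// --- illustrative sketch, not part of the commit ---------------------------
#include <cstdio>
#include <map>
#include <string>

int main() {
  // Keys are TensorFlow argument indices; they need not be contiguous or
  // densely packed at the end of the argument list.
  std::map<int, std::string> variable_args = {
      {7, "var_b"}, {3, "var_a"}, {12, "var_c"}};

  // std::map iteration visits keys in ascending order, so the signature is
  // built in a stable, index-ordered sequence: 3, 7, 12.
  for (const auto& iterator : variable_args) {
    std::printf("TF arg=%d name=%s\n", iterator.first, iterator.second.c_str());
  }
  return 0;
}
// ---------------------------------------------------------------------------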
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 0858020716..d506378314 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -63,7 +63,7 @@ class XlaCompilationCache : public ResourceBase {
// outputs.
Status Compile(const XlaCompiler::Options& options,
const NameAttrList& function, int num_constant_args,
- const std::vector<OptionalTensor>& variable_args,
+ const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable,
@@ -105,7 +105,7 @@ class XlaCompilationCache : public ResourceBase {
// Builds the signature for a compilation.
Status BuildSignature(const NameAttrList& function, int num_constant_args,
- const std::vector<OptionalTensor>& variable_args,
+ const std::map<int, OptionalTensor>& variable_args,
OpKernelContext* ctx, Signature* signature);
// The value associated with a cache entry.
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
new file mode 100644
index 0000000000..8322dd2e82
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -0,0 +1,255 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace gpu = perftools::gputools;
+
+namespace tensorflow {
+
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+ int num_variables) {
+ std::map<int, OptionalTensor> snapshot;
+ int first_variable = ctx->num_inputs() - num_variables;
+ for (int i = 0; i < num_variables; ++i) {
+ Var* variable = nullptr;
+ ResourceHandle handle = HandleFromInput(ctx, first_variable + i);
+ OptionalTensor& tensor = snapshot[first_variable + i];
+ if (LookupResource(ctx, handle, &variable).ok()) {
+ tf_shared_lock lock(*variable->mu());
+ tensor.name = handle.name();
+ tensor.present = true;
+ tensor.value = *variable->tensor();
+ }
+ }
+ return snapshot;
+}
+
+XlaAllocator::XlaAllocator(const gpu::Platform* platform,
+ OpKernelContext* op_context)
+ : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
+
+XlaAllocator::~XlaAllocator() = default;
+
+xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) {
+ AllocatorAttributes allocator_attrs;
+ allocator_attrs.set_on_host(false);
+
+ AllocationAttributes allocation_attrs;
+ allocation_attrs.no_retry_on_failure = !retry_on_failure;
+
+ Tensor t;
+ Status status = op_context_->allocate_temp(
+ DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
+ allocation_attrs);
+ if (!status.ok()) {
+ VLOG(2) << "Allocation failed " << size;
+ return status;
+ }
+ void* data =
+ reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
+ tensors_[data] = t;
+ return gpu::DeviceMemoryBase(data, size);
+}
+
+Status XlaAllocator::RegisterArgument(const Tensor* t) {
+ void* data =
+ reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
+ tensors_[data] = *t;
+ return Status::OK();
+}
+
+Status XlaAllocator::Deallocate(int device_ordinal,
+ gpu::DeviceMemoryBase* mem) {
+ if (mem->opaque() != nullptr) {
+ if (tensors_.erase(mem->opaque()) == 0) {
+ return tensorflow::errors::InvalidArgument("Unknown tensor address");
+ }
+ }
+ return Status::OK();
+}
+
+Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
+ DataType dtype,
+ const TensorShape& shape,
+ Tensor* out_tensor) const {
+ void* ptr = const_cast<void*>(buffer.opaque());
+ auto it = tensors_.find(ptr);
+ if (it == tensors_.end()) {
+ return errors::InvalidArgument("Unknown tensor address");
+ }
+ const Tensor& tensor = it->second;
+
+ int64 output_size = DataTypeSize(dtype) * shape.num_elements();
+ if (tensor.TotalBytes() == output_size) {
+ out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
+ } else {
+ Tensor slice = tensor.Slice(0, output_size);
+ out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
+ }
+ return Status::OK();
+}
+
+XlaComputationLaunchContext::XlaComputationLaunchContext(
+ int64 num_resource_args, xla::LocalClient* client,
+ XlaAllocator* xla_allocator)
+ : num_resource_args_(num_resource_args),
+ client_(client),
+ xla_allocator_(xla_allocator) {}
+
+void XlaComputationLaunchContext::PopulateInputs(
+ OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+ const std::map<int, OptionalTensor>& variables) {
+ // Build xla::ShapedBuffers that point directly to the Tensor buffers.
+ arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
+ arg_buffers_.resize(kernel->xla_input_shapes.size());
+ arg_ptrs_ = std::vector<xla::ShapedBuffer*>(arg_buffers_.size());
+
+ // Pass remaining parameters.
+ const Tensor* t;
+ for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
+ int arg_num = kernel->input_mapping[i];
+ const xla::Shape& shape = kernel->xla_input_shapes[i];
+ if (variables.count(arg_num)) {
+ t = &(variables.at(arg_num).value);
+ CHECK(t);
+ } else {
+ t = &(ctx->input(arg_num));
+ }
+
+ gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase(
+ const_cast<char*>(t->tensor_data().data()), t->tensor_data().size());
+
+ const xla::Shape on_device_shape =
+ client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
+ CHECK(xla::ShapeUtil::Equal(shape, on_device_shape))
+ << "On-device shape "
+ << xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
+ << " not the same as on-host shape "
+ << xla::ShapeUtil::HumanStringWithLayout(shape);
+ arg_buffers_[i] = xla::MakeUnique<xla::ShapedBuffer>(
+ /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(),
+ client_->default_device_ordinal());
+ arg_buffers_[i]->set_buffer(dmem, /*index=*/{});
+ arg_ptrs_[i] = arg_buffers_[i].get();
+
+ OP_REQUIRES_OK(ctx, xla_allocator_->RegisterArgument(t));
+ }
+}
+
+void XlaComputationLaunchContext::PopulateOutputs(
+ OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
+ std::unique_ptr<xla::ShapedBuffer> output) {
+ gpu::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ // Computation output should always be a tuple.
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString();
+ }
+ CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+
+ // Copy XLA results to the OpOutputList.
+ int output_num = 0;
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ if (kernel->outputs[i].is_constant) {
+ // Output is a constant.
+ const Tensor& const_tensor = kernel->outputs[i].constant_value;
+ const size_t total_bytes = const_tensor.TotalBytes();
+ if (stream && total_bytes > 0) {
+ // Copy host -> device. (Empty tensors don't have backing buffers.)
+ VLOG(1) << "Constant output tensor on device";
+ Tensor* output_tensor;
+ TF_CHECK_OK(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+
+ const void* src_ptr = DMAHelper::base(&const_tensor);
+ void* dst_ptr = DMAHelper::base(output_tensor);
+ gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+ stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+ } else {
+ // No copy required.
+ ctx->set_output(i, const_tensor);
+ }
+ } else {
+ const TensorShape& shape = kernel->outputs[i].shape;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
+
+ gpu::DeviceMemoryBase buffer = output->buffer({output_num});
+ Tensor output_tensor;
+ // Looks up the owning Tensor by buffer address.
+ OP_REQUIRES_OK(ctx, xla_allocator_->MakeTensorFromBuffer(
+ buffer, ctx->expected_output_dtype(i), shape,
+ &output_tensor));
+ ctx->set_output(i, output_tensor);
+ ++output_num;
+ }
+
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << ctx->mutable_output(i)->DebugString();
+ }
+ }
+
+ // Apply variable updates, if any.
+ VLOG(2) << "Applying variable updates";
+ for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+ const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
+ OP_REQUIRES(ctx,
+ write.input_index >= 0 && write.input_index < ctx->num_inputs(),
+ errors::Internal("Invalid input index for variable write."));
+
+ gpu::DeviceMemoryBase buffer = output->buffer({output_num});
+
+ Var* variable = nullptr;
+ // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
+ // not a Tensor.
+ OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index),
+ &variable, [this, ctx, &write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
+
+ core::ScopedUnref s(variable);
+
+ mutex_lock ml(*variable->mu());
+ OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
+ errors::Internal("Mismatched type in variable write"));
+
+ // Looks up the owning Tensor by buffer address.
+ OP_REQUIRES_OK(ctx,
+ xla_allocator_->MakeTensorFromBuffer(
+ buffer, write.type, write.shape, variable->tensor()));
+ ++output_num;
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
new file mode 100644
index 0000000000..9fd356fce5
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains utilities for launching compiled XLA kernels for a KernelContext.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_LAUNCH_UTIL_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class XlaAllocator;
+
+// Takes a snapshot of the values of resource variable arguments, which are
+// the last `num_variables` arguments. We snapshot tensors that back
+// resource variables since concurrent updates may modify the shape, and it is
+// important that the shapes used for compilation match the true shapes of the
+// buffers.
+//
+// Returns a map of TensorFlow argument index to resource variable.
+std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
+ int num_variables);
+
+// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
+// Assumes that the Tensorflow allocator permits asynchronous deallocation:
+// see comment on `AllowsAsynchronousDeallocation()`.
+class XlaAllocator : public xla::DeviceMemoryAllocator {
+ public:
+ XlaAllocator(const perftools::gputools::Platform* platform,
+ OpKernelContext* op_context);
+ ~XlaAllocator() override;
+ xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) override;
+ Status Deallocate(int device_ordinal,
+ perftools::gputools::DeviceMemoryBase* mem) override;
+
+ // Register an Tensor (input or resource variable) with the allocator. If
+ // the operation returns an alias to one of its inputs, then the allocator
+ // needs to be able to handle it.
+ Status RegisterArgument(const Tensor* t);
+
+ // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
+ // interpreted as having data type 'dtype' and shape 'shape'.
+ Status MakeTensorFromBuffer(perftools::gputools::DeviceMemoryBase buffer,
+ DataType dtype, const TensorShape& shape,
+ Tensor* out_tensor) const;
+
+ // The Tensorflow BFC allocator used on GPU allows host-side deallocation
+ // before GPU execution takes place. Tensorflow uses the ordering of the main
+ // compute stream to enforce a happens-before relationship between a memory
+ // allocation and code that reuses the same memory. If Tensorflow adds
+ // support for multiple GPU streams or allocators with different ordering
+ // requirements, this code may need to change.
+ // (This attribute has no effect on CPU.)
+ bool AllowsAsynchronousDeallocation() const override { return true; }
+
+ private:
+ OpKernelContext* const op_context_;
+
+ // Map from pointer address to the owning Tensor; used by
+ // MakeTensorFromBuffer. Also used to automatically release Tensors when the
+ // allocator is freed.
+ std::unordered_map<void*, Tensor> tensors_;
+};
+
+// Helper class to perform the marshalling of TensorFlow inputs and outputs to
+// ShapedBuffers suitable for passing to an XLA computation.
+class XlaComputationLaunchContext {
+ public:
+ XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client,
+ XlaAllocator* xla_allocator);
+
+ // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
+ // `variables` is a map from TensorFlow argument number to resource variable.
+ void PopulateInputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ const std::map<int, OptionalTensor>& variables);
+
+ // Given the XLA output in `output`, populate all outputs of `ctx`.
+ void PopulateOutputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ std::unique_ptr<xla::ShapedBuffer> output);
+
+ // Return the argument list. Only valid after PopulateInputs() has been
+ // called.
+ const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
+
+ private:
+ int64 num_resource_args_;
+ xla::LocalClient* client_;
+ XlaAllocator* xla_allocator_;
+ std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
+ std::vector<xla::ShapedBuffer*> arg_ptrs_;
+};
+
+} // namespace tensorflow
+
+#endif
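
As its comments above describe, XlaAllocator tracks every buffer it hands to XLA in a
map keyed by the buffer's base address, so that MakeTensorFromBuffer can later recover
the owning Tensor from an opaque device pointer. A standalone sketch of that
address-keyed bookkeeping, with a hypothetical OwnedBuffer type standing in for Tensor
and AddressKeyedRegistry standing in for the allocator, is:

// --- illustrative sketch, not part of the commit ---------------------------
#include <cstddef>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a Tensor that owns some backing storage.
struct OwnedBuffer {
  std::vector<uint8_t> data;
  void* base() { return data.data(); }
};

class AddressKeyedRegistry {
 public:
  // Mirrors Allocate()/RegisterArgument(): remember who owns each address.
  void* Allocate(std::size_t size) {
    auto buf = std::make_shared<OwnedBuffer>();
    buf->data.resize(size);
    void* base = buf->base();
    owners_[base] = buf;
    return base;
  }

  // Mirrors Deallocate(): forgetting an unknown address is an error.
  bool Deallocate(void* base) { return owners_.erase(base) == 1; }

  // Mirrors MakeTensorFromBuffer(): recover the owner from an opaque address.
  std::shared_ptr<OwnedBuffer> Lookup(void* base) const {
    auto it = owners_.find(base);
    return it == owners_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<void*, std::shared_ptr<OwnedBuffer>> owners_;
};
// ---------------------------------------------------------------------------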