author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-30 16:52:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-30 17:00:58 -0700
commit    692e3fb36560b1cee3de10d48b758e9273da71a9 (patch)
tree      c8a0b8aff4cca9b60ee8761d60bf5ae7caf6fe8b /tensorflow/compiler/xrt
parent    2391e5204b70de5275a0180ca6c9ae5396cec723 (diff)
First checkin of XRT primitives to call XLA computations directly via the
TensorFlow infrastructure. These primitives will allow direct access to XLA
via Cloud TPU, and also provide an easy path to integrate custom XLA
computations with other (distributed) TensorFlow computations on CPU and GPU.
The API is *very experimental* and subject to change, perhaps substantially in
the short term.

PiperOrigin-RevId: 211006371
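As a rough illustration of how these primitives fit together, the sketch below strings the new ops into the compile / allocate / execute / read-back flow that tests/raw_api_test.cc exercises. It assumes the C++ wrappers generated by //tensorflow/compiler/xrt/cc:xrt_ops follow the usual tf_gen_op_wrappers_cc conventions (one class per op, with a member per output named after that output); the CompileAndRun helper and the XLA_CPU device string are illustrative assumptions, not names from this change.

// Hedged sketch; the op names are the ones registered in this change, while
// the helper name, device string, and wrapper member names are assumed.
#include <vector>

#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tf = tensorflow;

void CompileAndRun(const xrt::XLAComputation& computation,
                   const xrt::XLAAllocation& arg0) {
  tf::Scope root = tf::Scope::NewRootScope();
  tf::Scope device = root.WithDevice("/device:XLA_CPU:0");

  // The request protos travel through the graph as string scalars.
  auto c_input = tf::ops::Const(root, computation.SerializeAsString());
  auto a_input = tf::ops::Const(root, arg0.SerializeAsString());
  xrt::XRTExecutionConfig exec_config;  // defaults: core 0, keep handles
  auto e_input = tf::ops::Const(root, exec_config.SerializeAsString());

  // Compile -> allocate -> execute -> read back, all via int64 handles.
  auto c_handle = tf::ops::XRTCompile(device, c_input);
  auto a_handle = tf::ops::XRTAllocate(device, a_input);
  auto result = tf::ops::XRTExecute(device, c_handle.handle, e_input,
                                    {tf::Output(a_handle.handle)});
  auto read_back =
      tf::ops::XRTReadLiteralAndRelease(device, result.output_handle);

  tf::ClientSession session(root);
  std::vector<tf::Tensor> outputs;
  TF_CHECK_OK(session.Run({read_back}, &outputs));
  // outputs[0] now holds a serialized xla::LiteralProto.
}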
Diffstat (limited to 'tensorflow/compiler/xrt')
-rw-r--r--  tensorflow/compiler/xrt/BUILD                        |  83
-rw-r--r--  tensorflow/compiler/xrt/cc/BUILD                     |  20
-rw-r--r--  tensorflow/compiler/xrt/kernels/BUILD                |  72
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc   | 239
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_execute_op.cc    | 254
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_state_ops.cc     | 110
-rw-r--r--  tensorflow/compiler/xrt/kernels/xrt_state_ops.h      | 424
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_compile_ops.cc       |  48
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_execute_op.cc        |  44
-rw-r--r--  tensorflow/compiler/xrt/ops/xrt_state_ops.cc         | 122
-rw-r--r--  tensorflow/compiler/xrt/tests/BUILD                  |  65
-rw-r--r--  tensorflow/compiler/xrt/tests/raw_api_test.cc        | 421
-rw-r--r--  tensorflow/compiler/xrt/xrt.proto                    |  78
-rw-r--r--  tensorflow/compiler/xrt/xrt_compilation_cache.cc     | 263
-rw-r--r--  tensorflow/compiler/xrt/xrt_compilation_cache.h      | 238
-rw-r--r--  tensorflow/compiler/xrt/xrt_device.cc                |  46
-rw-r--r--  tensorflow/compiler/xrt/xrt_device.h                 |  66
-rw-r--r--  tensorflow/compiler/xrt/xrt_state.cc                 | 458
-rw-r--r--  tensorflow/compiler/xrt/xrt_state.h                  | 208
19 files changed, 3259 insertions, 0 deletions
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
new file mode 100644
index 0000000000..efbe980278
--- /dev/null
+++ b/tensorflow/compiler/xrt/BUILD
@@ -0,0 +1,83 @@
+# Description: Operations defined for XRT
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_libs",
+)
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "xrt_proto",
+ srcs = ["xrt.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ ],
+)
+
+cc_library(
+ name = "xrt_utils",
+ srcs = [
+ "xrt_compilation_cache.cc",
+ "xrt_device.cc",
+ "xrt_state.cc",
+ ],
+ hdrs = [
+ "xrt_compilation_cache.h",
+ "xrt_device.h",
+ "xrt_state.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/jit:xla_device",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "xrt_server",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":xrt_compile_ops_op_lib",
+ ":xrt_execute_op_op_lib",
+ ":xrt_state_ops_op_lib",
+ "//tensorflow/compiler/xrt/kernels:xrt_ops",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/cc/BUILD b/tensorflow/compiler/xrt/cc/BUILD
new file mode 100644
index 0000000000..5c1e86b76b
--- /dev/null
+++ b/tensorflow/compiler/xrt/cc/BUILD
@@ -0,0 +1,20 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrappers_cc",
+)
+
+tf_gen_op_wrappers_cc(
+ name = "xrt_ops",
+ op_lib_names = [
+ "xrt_compile_ops",
+ "xrt_state_ops",
+ "xrt_execute_op",
+ ],
+ pkg = "//tensorflow/compiler/xrt",
+)
diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
new file mode 100644
index 0000000000..68ba17a424
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -0,0 +1,72 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler/xrt:__subpackages__",
+ ],
+)
+
+cc_library(
+ name = "xrt_state_ops",
+ hdrs = ["xrt_state_ops.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xrt_ops",
+ srcs = [
+ "xrt_compile_ops.cc",
+ "xrt_execute_op.cc",
+ "xrt_state_ops.cc",
+ ],
+ deps = [
+ ":xrt_state_ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:compile_only_client",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_placer",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/stream_executor:stream_executor_headers_lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
new file mode 100644
index 0000000000..5cf2bc8861
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -0,0 +1,239 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for compiling XLA computations and managing handles that refer to
+// them.
+
+#include <cstdlib>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+const int kDefaultCacheSize = 100;
+
+class XRTCompileOp : public OpKernel {
+ public:
+ explicit XRTCompileOp(OpKernelConstruction* ctx);
+ ~XRTCompileOp() override;
+ XRTCompileOp(const XRTCompileOp&) = delete;
+ XRTCompileOp& operator=(const XRTCompileOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ Status Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program);
+};
+
+Status CompilationCacheKey(const xrt::XLAComputation& computation,
+ string* key) {
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ *key = strings::StrCat(fingerprint);
+ return Status::OK();
+}
+
+XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+Status XRTCompileOp::Compile(OpKernelContext* ctx,
+ const xrt::XLAComputation& computation_proto,
+ std::unique_ptr<xla::LocalExecutable>* program) {
+ const xrt::XLAComputationConfig& config = computation_proto.config();
+
+ // The default config value is 0; treat it as 1 for convenience.
+ int num_replicas = config.num_replicas() ? config.num_replicas() : 1;
+ TF_RET_CHECK(num_replicas == 1);
+ int num_cores_per_replica =
+ config.num_cores_per_replica() ? config.num_cores_per_replica() : 1;
+ TF_RET_CHECK(num_cores_per_replica == 1);
+ TF_RET_CHECK(config.per_core_program_shape_size() == 0);
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(ctx, 0, &device_ref));
+
+ xla::LocalClient* client = device_ref.client();
+
+ // There is officially no way to use XLA in a client/server architecture where
+ // client and server are built from different revisions, because the XLA team
+ // does not want to give any guarantees about the stability of the HLO
+ // proto. For Cloud TPU this is fine, because server and client versions can
+ // be assumed to be synced to the same version. For general use, the mechanism
+ // here (using a snapshot from XlaComputation) works as well as the "official"
+ // XLA client/server design, which serializes the same proto between client
+ // and server, so in reality it is probably fine.
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
+ client->LoadSnapshot(computation_proto.hlo_snapshot()));
+
+ std::vector<const xla::Shape*> argument_layouts(
+ config.program_shape().parameters_size());
+ for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
+ argument_layouts[i] = &config.program_shape().parameters(i);
+ }
+ xla::ExecutableBuildOptions build_options;
+ build_options.set_device_ordinal(client->default_device_ordinal());
+ build_options.set_result_layout(config.program_shape().result());
+ build_options.set_device_allocator(device_ref.backend()->memory_allocator());
+
+ VLOG(1) << "Building executable";
+ auto compile_result =
+ client->Compile(computation, argument_layouts, build_options);
+ if (!compile_result.ok()) {
+ return compile_result.status();
+ }
+ *program = std::move(compile_result.ValueOrDie());
+ return Status::OK();
+}
+
+void XRTCompileOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTCompileOp::Compute";
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ const Tensor& computation_input = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
+ errors::Internal("computation input should be a string scalar"));
+
+ xrt::XLAComputation computation_proto;
+ OP_REQUIRES(
+ ctx,
+ computation_proto.ParseFromString(computation_input.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse computation input to XLAComputation"));
+
+ string key;
+ OP_REQUIRES_OK(ctx, CompilationCacheKey(computation_proto, &key));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx,
+ rm->LookupOrCreate<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName,
+ &cache, [](XRTCompilationCache** new_cache) {
+ *new_cache = new XRTCompilationCache(kDefaultCacheSize);
+ return Status::OK();
+ }));
+ core::ScopedUnref cache_unref(cache);
+
+ int64 uid;
+ OP_REQUIRES_OK(
+ ctx, cache->CompileIfKeyAbsent(
+ key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
+ VLOG(1) << "Compiling XLA executable";
+ return Compile(ctx, computation_proto, program);
+ }));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = uid;
+ ctx->set_output(0, output);
+}
+
+XRTCompileOp::~XRTCompileOp() = default;
+
+class XRTReleaseCompilationRefOp : public OpKernel {
+ public:
+ explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
+ ~XRTReleaseCompilationRefOp() override;
+ XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
+ XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
+ delete;
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
+XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
+ OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;
+
+void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
+
+ const Tensor& key_tensor = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(key_tensor.shape()),
+ errors::Internal("computation key should be a string scalar"));
+ int64 uid = key_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
+
+ // Process-wide cache of XLA executables.
+ XRTCompilationCache* cache;
+ OP_REQUIRES_OK(ctx, rm->Lookup<XRTCompilationCache>(
+ rm->default_container(),
+ kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ OP_REQUIRES_OK(ctx, cache->Release(uid));
+
+ VLOG(2) << "Released computation handle " << uid;
+}
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+REGISTER_KERNEL_BUILDER(Name("XRTCompile")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation")
+ .HostMemory("handle"),
+ XRTCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseCompilationRefOp);
+
+} // namespace tensorflow
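For reference, the computation input parsed above is a serialized xrt::XLAComputation. The sketch below fills one in from an xla::XlaComputation; it touches only the proto fields this kernel reads (config().program_shape() and hlo_snapshot()) and assumes this era's XlaComputation::Snapshot() accessor, so exact field and method names should be checked against xrt.proto and the XLA client headers.

// Hedged sketch: builds the proto that XRTCompile expects, for an f32[2] add.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"

xrt::XLAComputation MakeAddComputation() {
  xla::XlaBuilder builder("add");
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  auto p0 = xla::Parameter(&builder, 0, shape, "p0");
  auto p1 = xla::Parameter(&builder, 1, shape, "p1");
  xla::Add(p0, p1);
  xla::XlaComputation computation = builder.Build().ValueOrDie();

  xrt::XLAComputation c;
  auto* config = c.mutable_config();
  auto* program_shape = config->mutable_program_shape();
  *program_shape->add_parameters() = shape;
  *program_shape->add_parameters() = shape;
  *program_shape->mutable_result() = shape;
  // Snapshot() produces the HLO snapshot that XRTCompileOp::Compile loads
  // back via LocalClient::LoadSnapshot().
  *c.mutable_hlo_snapshot() = *computation.Snapshot().ValueOrDie();
  return c;
}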
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
new file mode 100644
index 0000000000..257b054f16
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -0,0 +1,254 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+
+namespace {
+
+uint32 InitialRandomSeed() {
+ // Support for plumbing the TF seed through to XLA is being worked on.
+ // If a user wants deterministic behavior, their best option is to start
+ // from a known checkpoint. This also sidesteps issues that arise because
+ // the TF executor may invoke multiple random calls in any order.
+ // Another option is to use stateless random ops, which have much cleaner
+ // semantics.
+ // If a user really wants to set a deterministic seed for XLA-based
+ // devices, this is the place to do it.
+ std::random_device rd;
+ // Make the starting value odd.
+ return rd() | 1;
+}
+
+uint32 GetXLARandomSeed() {
+ // We initialize the counter with an odd number and increment it by two
+ // every time. This ensures that it is never zero, even after an
+ // overflow; when seeded with zero, some XLA backends can return all
+ // zeros instead of random numbers.
+ static std::atomic<uint32> counter(InitialRandomSeed());
+ return counter.fetch_add(2);
+}
+
+// Looks up the input `key` in the compilation cache.
+Status GetComputationCacheEntry(
+ XRTCompilationCache* cache, int64 key,
+ std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
+ return Status::OK();
+}
+
+// Looks up the computation's input allocations from their handles and fills
+// in the tuple, buffer, and pointer vectors.
+Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
+ bool release_inputs,
+ std::vector<XRTTupleAllocation*>* input_tuples,
+ std::vector<xla::ShapedBuffer>* input_allocations,
+ std::vector<xla::ShapedBuffer*>* input_pointers) {
+ OpInputList arg_list;
+ TF_RETURN_IF_ERROR(context->input_list("input_handles", &arg_list));
+
+ input_tuples->resize(arg_list.size());
+ input_pointers->resize(arg_list.size());
+ for (int i = 0; i < arg_list.size(); ++i) {
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(arg_list[i].shape()));
+ int64 input_uid = arg_list[i].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, input_uid, &(*input_tuples)[i]));
+ if (release_inputs) {
+ // We are holding a reference to the tuple, so we can safely delete it
+ // from the resource manager here.
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::DeleteFromResourceManager(rm, input_uid));
+ VLOG(2) << "Released allocation handle " << input_uid;
+ }
+ XRTTupleAllocation* tuple = (*input_tuples)[i];
+ input_allocations->emplace_back(tuple->ToShapedBuffer());
+ }
+ for (int i = 0; i < arg_list.size(); ++i) {
+ (*input_pointers)[i] = &(*input_allocations)[i];
+ }
+ return Status::OK();
+}
+
+// XRTExecuteOp
+
+class XRTExecuteOp : public AsyncOpKernel {
+ public:
+ explicit XRTExecuteOp(OpKernelConstruction* context);
+ ~XRTExecuteOp() override;
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+ Status DoWork(OpKernelContext* context);
+};
+
+XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
+ // Schedule onto the default queue, for unbounded concurrency. See b/73520706
+ Env::Default()->SchedClosure([this, context, done]() {
+ OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
+ done();
+ });
+}
+
+Status XRTExecuteOp::DoWork(OpKernelContext* context) {
+ VLOG(1) << "XRTExecuteOp::Compute";
+ ResourceMgr* rm;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::GetResourceManager(context, &rm));
+
+ const Tensor& execution_input = context->input(0);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
+ int64 compilation_handle = execution_input.scalar<int64>()();
+
+ const Tensor& execution_config = context->input(1);
+ TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
+ xrt::XRTExecutionConfig config_proto;
+ TF_RET_CHECK(
+ config_proto.ParseFromString(execution_config.scalar<string>()()));
+
+ int core_index_in_replica = config_proto.core_index_in_replica();
+ TF_RET_CHECK(core_index_in_replica == 0);
+ bool release_inputs = config_proto.release_input_handles();
+ bool release_compilation = config_proto.release_compilation_handle();
+
+ XRTCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->Lookup<XRTCompilationCache>(
+ rm->default_container(), kXRTCompilationCacheResourceName, &cache));
+ core::ScopedUnref cache_unref(cache);
+
+ std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+ TF_RETURN_IF_ERROR(cache->Lookup(compilation_handle, &entry));
+
+ if (release_compilation) {
+ // Drop the cache's reference to this compilation; the entry looked up above
+ // keeps the executable alive for the rest of this execution.
+ TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
+ VLOG(2) << "Released compilation handle " << compilation_handle;
+ }
+
+ std::vector<XRTTupleAllocation*> input_tuples;
+ // Make a cleanup method so that we can safely return in error conditions
+ // without leaking references to allocations.
+ auto buffer_releaser = gtl::MakeCleanup([&input_tuples]() {
+ for (auto tuple : input_tuples) {
+ if (tuple != nullptr) {
+ tuple->Unref();
+ }
+ }
+ });
+ std::vector<xla::ShapedBuffer> input_allocations;
+ std::vector<xla::ShapedBuffer*> input_pointers;
+ TF_RETURN_IF_ERROR(GetComputationInputs(context, rm, release_inputs,
+ &input_tuples, &input_allocations,
+ &input_pointers));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class XRTGenericDeviceAccessor::ScopedRef device_ref;
+ TF_RETURN_IF_ERROR(
+ XRTGenericDeviceAccessor::InitScopedRef(context, 0, &device_ref));
+
+ int rng_seed = config_proto.rng_seed();
+ if (rng_seed == 0) {
+ rng_seed = GetXLARandomSeed();
+ }
+
+ se::Stream* stream = context->op_device_context()
+ ? context->op_device_context()->stream()
+ : nullptr;
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(device_ref.backend()->memory_allocator());
+ run_options.set_intra_op_thread_pool(&context->eigen_cpu_device());
+ run_options.set_rng_seed(rng_seed);
+
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ xla::LocalExecutable* executable = entry->get().get_executable();
+ auto run_result = executable->Run(input_pointers, run_options);
+ if (!run_result.ok()) {
+ return run_result.status();
+ }
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ auto scoped_buffer = run_result.ConsumeValueOrDie();
+ auto shaped_buffer = scoped_buffer.release();
+ XRTTupleAllocation* output_tuple;
+ TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
+ shaped_buffer, device_ref.backend(), device_ref.device_ordinal(),
+ &output_tuple));
+
+ Tensor* output_tensor;
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, TensorShape({}), &output_tensor));
+ int64 key;
+ TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
+ output_tensor->scalar<int64>()() = key;
+
+ return Status::OK();
+}
+
+XRTExecuteOp::~XRTExecuteOp() = default;
+
+} // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("computation_handle")
+ .HostMemory("execution_config")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTExecuteOp);
+
+} // namespace tensorflow
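The execution above is driven entirely by the serialized xrt::XRTExecutionConfig passed as the second input. A minimal sketch of the fields DoWork() reads, assuming the standard protobuf setters for those fields:

#include "tensorflow/compiler/xrt/xrt.pb.h"

xrt::XRTExecutionConfig MakeExecutionConfig() {
  xrt::XRTExecutionConfig config;
  config.set_core_index_in_replica(0);           // must be 0; checked above
  config.set_release_input_handles(true);        // drop argument handles after the run
  config.set_release_compilation_handle(false);  // keep the executable cached
  config.set_rng_seed(0);                        // 0 => GetXLARandomSeed() is used
  return config;
}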
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
new file mode 100644
index 0000000000..ffea592491
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xrt/kernels/xrt_state_ops.h"
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+
+namespace tensorflow {
+
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTAllocate")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("allocation")
+ .HostMemory("handle"),
+ XRTAllocateOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("base_handle")
+ .HostMemory("shape_index")
+ .HostMemory("output_handle"),
+ XRTSubTupleOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("tuple_description")
+ .HostMemory("input_handles")
+ .HostMemory("output_handle"),
+ XRTMakeTupleOp<XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<false, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle")
+ .HostMemory("literal"),
+ XRTReadLiteralOp<true, XRTGenericDeviceAccessor>);
+
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_GPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle")
+ .Device(DEVICE_XLA_CPU)
+ .HostMemory("handle"),
+ XRTReleaseAllocationOp<XRTGenericDeviceAccessor>);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
new file mode 100644
index 0000000000..478c9663a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h
@@ -0,0 +1,424 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#ifndef TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+#define TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Helper functions for templated ops.
+class XRTStateHelpers {
+ public:
+ // The Status return value allows us to use the
+ // TF_ASSIGN_OR_RETURN macro, which doesn't work within the body of an
+ // OpKernel::Compute method.
+ static Status MakeLiteral(const xla::LiteralProto& proto,
+ std::unique_ptr<xla::Literal>* literal) {
+ TF_ASSIGN_OR_RETURN(*literal, xla::Literal::CreateFromProto(proto));
+ return Status::OK();
+ }
+
+ // ParseTupleNode recursively parses an xrt::XLATupleNode proto and generates
+ // the xla::Shape of the 'spine', i.e. the tuple shape where every leaf is an
+ // existing allocation. As a side effect it fills in input_vector by looking
+ // up allocations from handles in the input_tensor_list as they are
+ // referenced by nodes in the proto.
+ static Status ParseTupleNode(
+ const xrt::XLATupleNode& tuple_node, const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::Shape* shape, ResourceMgr* rm) {
+ if (tuple_node.tuples_size() > 0) {
+ // This is an internal node in the proto so descend recursively.
+ xla::Shape dummy = xla::ShapeUtil::MakeShapeWithType<float>({});
+ std::vector<xla::Shape> subshapes(tuple_node.tuples_size(), dummy);
+ *xla::ShapeUtil::GetMutableSubshape(shape, {}) =
+ xla::ShapeUtil::MakeTupleShape(subshapes);
+ for (int i = 0; i < tuple_node.tuples_size(); ++i) {
+ TF_RETURN_IF_ERROR(ParseTupleNode(
+ tuple_node.tuples(i), input_tensor_list, input_vector,
+ xla::ShapeUtil::GetMutableSubshape(shape, {i}), rm));
+ }
+ } else {
+ // This is a leaf node in the proto so look up the referenced input.
+ int input_index = tuple_node.input_index();
+ if (input_index < 0 || input_index >= input_vector->size()) {
+ return errors::InvalidArgument("Invalid tuple input index ",
+ input_index, ": MakeTuple has ",
+ input_vector->size(), " inputs.");
+ }
+ bool release_this_input = tuple_node.release_input_handle();
+ XRTTupleAllocation::ExpandedTupleInput& input =
+ input_vector->at(input_index);
+ if (input.allocation != nullptr &&
+ (input.release_allocation_after_use || release_this_input)) {
+ return errors::InvalidArgument(
+ "Invalid tuple tree: input index ", input_index,
+ " is repeated but release_input_handle is true.");
+ }
+ if (input.allocation == nullptr) {
+ // We haven't dereferenced this handle yet.
+ TF_RET_CHECK(
+ TensorShapeUtils::IsScalar(input_tensor_list[input_index].shape()));
+ int64 key = input_tensor_list[input_index].scalar<int64>()();
+ TF_RETURN_IF_ERROR(
+ XRTTupleAllocation::Lookup(rm, key, &input.allocation));
+ input.release_allocation_after_use = release_this_input;
+ }
+ }
+ return Status::OK();
+ }
+
+ // Parses an xrt::XLATupleNode proto recursively and returns the corresponding
+ // ShapeTree, where each leaf is an allocation corresponding to a handle in
+ // input_tensor_list. The ordinal of one of the allocations is returned in
+ // device_ordinal. Since it's not possible to specify an xrt::XLATupleNode
+ // with no leaves, device_ordinal will always be filled in by a successful
+ // call to ParseTupleTree.
+ static Status ParseTupleTree(
+ const xrt::XLATupleNode& tuple_tree_root,
+ const OpInputList& input_tensor_list,
+ std::vector<XRTTupleAllocation::ExpandedTupleInput>* input_vector,
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>* tuple_shape_tree,
+ int* device_ordinal, ResourceMgr* rm) {
+ // First get the shape of the 'spine' of the new tuple, where every leaf is
+ // an existing allocation. As a side-effect dereference the input handles
+ // into allocations in input_vector.
+ xla::Shape tuple_tree_shape;
+ TF_RETURN_IF_ERROR(ParseTupleNode(tuple_tree_root, input_tensor_list,
+ input_vector, &tuple_tree_shape, rm));
+ // Make the shape tree of allocations where the shape is the spine and each
+ // leaf is one of the allocations looked up in input_vector. Internal nodes
+ // have nullptr allocations.
+ *tuple_shape_tree = xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput>(
+ tuple_tree_shape);
+ tuple_shape_tree->ForEachMutableElement(
+ [&](const xla::ShapeIndex& index,
+ XRTTupleAllocation::ExpandedTupleInput* element) {
+ if (tuple_shape_tree->IsLeaf(index)) {
+ // Find the matching leaf in the proto tree.
+ const xrt::XLATupleNode* tuple_node = &tuple_tree_root;
+ for (int i = 0; i < index.size(); ++i) {
+ tuple_node = &tuple_node->tuples(index[i]);
+ }
+ // Copy the appropriate input allocation to the leaf of the
+ // tuple_shape_tree.
+ int input_index = tuple_node->input_index();
+ *element = input_vector->at(input_index);
+ CHECK(element->release_allocation_after_use ==
+ tuple_node->release_input_handle());
+ // We just need to know the device_ordinal of one of the
+ // allocations. We will validate later that they are all the same.
+ *device_ordinal = (*element).allocation->device_ordinal();
+ }
+ });
+ return Status::OK();
+ }
+};
+
+// Op that allocates memory for a literal and transfers it to the device.
+template <class DeviceAccessor>
+class XRTAllocateOp : public OpKernel {
+ public:
+ explicit XRTAllocateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTAllocateOp() override = default;
+ XRTAllocateOp(const XRTAllocateOp&) = delete;
+ XRTAllocateOp& operator=(const XRTAllocateOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTAllocateOp::Compute";
+
+ const Tensor& allocation_info = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
+ errors::Internal("allocation input should be a string scalar"));
+ xrt::XLAAllocation allocation_proto;
+ OP_REQUIRES(
+ ctx,
+ allocation_proto.ParseFromString(allocation_info.scalar<string>()()),
+ errors::InvalidArgument(
+ "Unable to parse allocation input to XLAAllocation"));
+
+ std::unique_ptr<xla::Literal> literal;
+ OP_REQUIRES_OK(
+ ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx,
+ DeviceAccessor::InitScopedRef(
+ ctx, allocation_proto.device_ordinal(), &device_ref));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
+ *literal, device_ref.backend(),
+ device_ref.device_ordinal(), &allocation));
+
+ // Intern takes ownership of our reference to allocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that takes a tuple handle input and returns a handle to a sub-tuple of the
+// input.
+template <bool discard_, class DeviceAccessor>
+class XRTSubTupleOp : public OpKernel {
+ public:
+ explicit XRTSubTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTSubTupleOp() override = default;
+ XRTSubTupleOp(const XRTSubTupleOp&) = delete;
+ XRTSubTupleOp& operator=(const XRTSubTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTSubTupleOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+ errors::Internal("computation input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ const Tensor& subtuple_info = ctx->input(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(subtuple_info.shape()),
+ errors::Internal("tuple index input should be an int32 vector"));
+ xla::ShapeIndex shape_index;
+ for (int i = 0; i < subtuple_info.dim_size(0); ++i) {
+ shape_index.push_back(subtuple_info.vec<int32>()(i));
+ }
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+ XRTTupleAllocation* suballocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::MakeSubBuffer(allocation, shape_index,
+ &suballocation, !discard_));
+
+ // Intern takes ownership of our reference to suballocation.
+ int64 key;
+ OP_REQUIRES_OK(ctx, suballocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that assembles existing allocations into a new tuple allocation.
+template <class DeviceAccessor>
+class XRTMakeTupleOp : public OpKernel {
+ public:
+ explicit XRTMakeTupleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTMakeTupleOp() override = default;
+ XRTMakeTupleOp(const XRTMakeTupleOp&) = delete;
+ XRTMakeTupleOp& operator=(const XRTMakeTupleOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTMakeTupleOp::Compute";
+
+ const Tensor& tuple_info = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(tuple_info.shape()),
+ errors::Internal("tuple description input should be a string scalar"));
+ xrt::XLATupleNode tuple_proto;
+ OP_REQUIRES(
+ ctx, tuple_proto.ParseFromString(tuple_info.scalar<string>()()),
+ errors::InvalidArgument("Unable to parse tuple input to XLATupleNode"));
+
+ OpInputList arg_list;
+ OP_REQUIRES_OK(ctx, ctx->input_list("input_handles", &arg_list));
+
+ // For each input, this vector holds the allocation it corresponds to and a
+ // flag indicating whether it should be released, i.e. discarded from the
+ // resource manager. One ref on each allocation is owned by this vector and
+ // freed on exit.
+ std::vector<XRTTupleAllocation::ExpandedTupleInput> input_vector(
+ arg_list.size());
+ auto cleanup = gtl::MakeCleanup([&input_vector] {
+ for (auto& input : input_vector) {
+ if (input.allocation != nullptr) {
+ input.allocation->Unref();
+ }
+ }
+ });
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ xla::ShapeTree<XRTTupleAllocation::ExpandedTupleInput> tuple_shape_tree;
+ // device_ordinal is filled in by ParseTupleTree with the ordinal of one of
+ // the allocations. It is guaranteed that there is at least one allocation in
+ // any legal tree. We validate below in XRTTupleAllocation::MakeTuple that
+ // all the allocations are on the same device.
+ int device_ordinal;
+ OP_REQUIRES_OK(ctx, XRTStateHelpers::ParseTupleTree(
+ tuple_proto, arg_list, &input_vector,
+ &tuple_shape_tree, &device_ordinal, rm));
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(
+ ctx, DeviceAccessor::InitScopedRef(ctx, device_ordinal, &device_ref));
+
+ XRTTupleAllocation* output_allocation;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::MakeTuple(
+ device_ref.backend(), device_ref.device_ordinal(),
+ tuple_shape_tree, &output_allocation));
+ // Add a ScopedUnref to simplify the error path while calling
+ // DeleteFromResourceManager.
+ core::ScopedUnref unref(output_allocation);
+ for (int i = 0; i < input_vector.size(); ++i) {
+ if (input_vector[i].release_allocation_after_use) {
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, arg_list[i].scalar<int64>()()));
+ }
+ }
+
+ // Intern takes ownership of a reference to output_allocation, so add
+ // another since the ScopedUnref will release one when this method exits.
+ output_allocation->Ref();
+ int64 key;
+ OP_REQUIRES_OK(ctx, output_allocation->Intern(rm, &key));
+
+ Tensor output(DT_INT64, TensorShape({}));
+ output.scalar<int64>()() = key;
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that reads a device-resident tuple to host memory and returns it as a
+// literal.
+template <bool discard_, class DeviceAccessor>
+class XRTReadLiteralOp : public OpKernel {
+ public:
+ explicit XRTReadLiteralOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReadLiteralOp() override = default;
+ XRTReadLiteralOp(const XRTReadLiteralOp&) = delete;
+ XRTReadLiteralOp& operator=(const XRTReadLiteralOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReadLiteralOp::Compute";
+
+ const Tensor& handle_tensor = ctx->input(0);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
+ errors::Internal("computation input should be an int64 scalar"));
+ int64 allocation_handle = handle_tensor.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ XRTTupleAllocation* allocation;
+ OP_REQUIRES_OK(
+ ctx, XRTTupleAllocation::Lookup(rm, allocation_handle, &allocation));
+ core::ScopedUnref allocation_unref(allocation);
+
+ if (discard_) {
+ VLOG(2) << "Releasing handle " << allocation_handle;
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(
+ rm, allocation_handle));
+ }
+
+ // We are guaranteed that the underlying device object won't be deleted out
+ // from under us while the ScopedRef is live.
+ class DeviceAccessor::ScopedRef device_ref;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::InitScopedRef(
+ ctx, allocation->device_ordinal(), &device_ref));
+
+ std::unique_ptr<xla::Literal> literal;
+ OP_REQUIRES_OK(
+ ctx, allocation->ToLiteral(device_ref.backend(),
+ device_ref.device_ordinal(), &literal));
+ xla::LiteralProto literal_proto = literal->ToProto();
+
+ Tensor output(DT_STRING, TensorShape({}));
+ literal_proto.SerializeToString(&output.scalar<string>()());
+ ctx->set_output(0, output);
+ }
+};
+
+// Op that discards a handle to device memory.
+template <class DeviceAccessor>
+class XRTReleaseAllocationOp : public OpKernel {
+ public:
+ explicit XRTReleaseAllocationOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~XRTReleaseAllocationOp() override = default;
+ XRTReleaseAllocationOp(const XRTReleaseAllocationOp&) = delete;
+ XRTReleaseAllocationOp& operator=(const XRTReleaseAllocationOp&) = delete;
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "XRTReleaseAllocationOp::Compute";
+
+ const Tensor& allocation_handle = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_handle.shape()),
+ errors::Internal("handle input should be an int64 scalar"));
+ int64 key = allocation_handle.scalar<int64>()();
+
+ ResourceMgr* rm;
+ OP_REQUIRES_OK(ctx, DeviceAccessor::GetResourceManager(ctx, &rm));
+
+ OP_REQUIRES_OK(ctx, XRTTupleAllocation::DeleteFromResourceManager(rm, key));
+
+ VLOG(2) << "Released allocation handle " << key;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_KERNELS_XRT_STATE_OPS_H_
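The tuple description consumed by XRTMakeTupleOp is an xrt::XLATupleNode tree whose leaves index into input_handles, as parsed by ParseTupleNode above. Below is a minimal sketch describing the tuple (h0, (h1, h2)) with h2 released after use; it relies only on the fields the parser reads (tuples, input_index, release_input_handle) and assumes the standard protobuf setters for them.

#include "tensorflow/compiler/xrt/xrt.pb.h"

// Hedged sketch: describes the output tuple (h0, (h1, h2)), where h0..h2 are
// the handles passed as input_handles[0..2] and h2 is released afterwards.
xrt::XLATupleNode MakeTupleDescription() {
  xrt::XLATupleNode root;
  root.add_tuples()->set_input_index(0);         // leaf: aliases input_handles[0]
  xrt::XLATupleNode* inner = root.add_tuples();  // internal node
  inner->add_tuples()->set_input_index(1);       // leaf: aliases input_handles[1]
  xrt::XLATupleNode* last = inner->add_tuples();
  last->set_input_index(2);
  last->set_release_input_handle(true);          // discard input_handles[2]
  return root;
}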
diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
new file mode 100644
index 0000000000..5cfc8711f9
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTCompile")
+ .Input("computation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a computation proto, compiles it, and places it in the global compilation
+cache.
+
+'computation' is a serialized xrt::XLAComputation proto.
+'handle' is an identifier that can be used in other ops to refer to the
+computation.
+)");
+
+REGISTER_OP("XRTReleaseCompilationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards a computation from the compilation cache. The handle cannot be
+subsequently used.
+
+'handle' is an id returned from a XRTCompile Op.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
new file mode 100644
index 0000000000..fda4c31298
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTExecute")
+ .Attr("Ninputs: int")
+ .Input("computation_handle: int64")
+ .Input("execution_config: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Runs a previously-compiled computation on a core. If
+execution_config.release_input_handles is true, the input handles are invalid
+after this op runs.
+
+'computation_handle' is an id returned by XRTCompile.
+'execution_config' is a serialized xrt::XRTExecutionConfig proto.
+'input_handles' is a list of ids of allocations, one per input to the compiled
+computation.
+'output_handle' is an identifier for the result of the compiled computation.
+'Ninputs' is the number of input handles.
+)");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/ops/xrt_state_ops.cc b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
new file mode 100644
index 0000000000..07d025ce34
--- /dev/null
+++ b/tensorflow/compiler/xrt/ops/xrt_state_ops.cc
@@ -0,0 +1,122 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XRTAllocate")
+ .Input("allocation: string")
+ .Output("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Reads a literal proto and transfers it to device memory.
+
+'allocation' is a serialized xrt::XLAAllocation proto.
+'handle' is an id that can be used in other ops to refer to the allocation.
+)");
+
+REGISTER_OP("XRTSubTuple")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used in other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTSubTupleAndRelease")
+ .Input("base_handle: int64")
+ .Input("shape_index: int32")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a sub-tuple of an allocated tuple, and releases the handle
+of the input tuple.
+
+'base_handle' is the id of the on-device allocation.
+'shape_index' is a vector of integers describing an XLA ShapeIndex.
+'output_handle' is an id that can be used by other ops to refer to the
+sub-tuple.
+)");
+
+REGISTER_OP("XRTMakeTuple")
+ .Attr("Ninputs: int")
+ .Input("tuple_description: string")
+ .Input("input_handles: Ninputs * int64")
+ .Output("output_handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Returns a handle to a new allocation constructed by assembling existing
+allocations in a tuple.
+
+'tuple_description' is a serialized xrt::XLATupleNode proto describing the
+shape of the output tuple, and whether each input handle should be aliased or
+released.
+'input_handles' is a list of input handles to assemble into the output tuple.
+'output_handle' is an id that can be used by other ops to refer to the new
+tuple.
+'Ninputs' is the number of input handles.
+)");
+
+REGISTER_OP("XRTReadLiteral")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory and returns it as a literal.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReadLiteralAndRelease")
+ .Input("handle: int64")
+ .Output("literal: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(
+ R"(
+Copies an allocated tuple from device memory, returns it as a literal, and
+releases the handle.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+'literal' is a serialized xla::LiteralProto proto.
+)");
+
+REGISTER_OP("XRTReleaseAllocationHandle")
+ .Input("handle: int64")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(
+ R"(
+Discards an allocation from device memory. The handle cannot be subsequently
+used.
+
+'handle' is the id returned from the Op that produced the on-device allocation.
+)");
+
+} // namespace tensorflow
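The allocation input to XRTAllocate wraps a literal proto plus a device ordinal. A minimal sketch, assuming only the value and device_ordinal fields that XRTAllocateOp reads and the LiteralUtil factory used by the tests below:

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"

xrt::XLAAllocation MakeVectorAllocation() {
  xrt::XLAAllocation alloc;
  alloc.set_device_ordinal(0);
  auto literal = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  *alloc.mutable_value() = literal->ToProto();  // xla::LiteralProto payload
  return alloc;
}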
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
new file mode 100644
index 0000000000..09ab4ed95f
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -0,0 +1,65 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//learning/brain:__subpackages__",
+ "//tensorflow/compiler:__subpackages__",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+
+cc_library(
+ name = "raw_api_test_lib",
+ testonly = 1,
+ srcs = [
+ "raw_api_test.cc",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xrt:xrt_proto",
+ "//tensorflow/compiler/xrt:xrt_server",
+ "//tensorflow/compiler/xrt/cc:xrt_ops",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
+ name = "raw_api_test_cpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_CPU"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_cpu_device",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "raw_api_test_gpu",
+ size = "medium",
+ srcs = [],
+ args = ["--xla_test_device=XLA_GPU"],
+ tags = ["requires-gpu-sm35"],
+ deps = [
+ ":raw_api_test_lib",
+ "//tensorflow/compiler/jit:xla_gpu_device",
+ ],
+)
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
new file mode 100644
index 0000000000..5b8516bf1d
--- /dev/null
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -0,0 +1,421 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
+#include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+string* xla_test_device_ptr; // initial value set in main()
+
+string DeviceFromFlag() {
+ string xla_test_device = *xla_test_device_ptr;
+ return absl::StrCat("/device:", xla_test_device, ":0");
+}
+
+xla::LiteralProto TwoElementTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ return tuple->ToProto();
+}
+
+xla::LiteralProto ScalarLiteral() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ return scalar->ToProto();
+}
+
+xla::LiteralProto NestedTuple() {
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto nested = xla::LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
+ return nested->ToProto();
+}
+
+xla::LiteralProto MakeTuple0() {
+ auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
+ auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
+ auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
+ auto tuple = xla::LiteralUtil::MakeTuple({array.get(), matrix.get()});
+ auto nested0 = xla::LiteralUtil::MakeTuple({scalar.get(), tuple.get()});
+ auto nested1 = xla::LiteralUtil::MakeTuple({scalar.get(), nested0.get()});
+ return nested1->ToProto();
+}
+
+xla::LiteralProto FloatVector(gtl::ArraySlice<float> v) {
+ auto array = xla::LiteralUtil::CreateR1<float>(v);
+ return array->ToProto();
+}
+
+bool CompareLiteralProtos(const xla::LiteralProto& a,
+ const xla::LiteralProto& b) {
+ auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = *l_a == *l_b;
+ if (!equal) {
+ LOG(INFO) << "LiteralProtos don't match " << a.DebugString()
+ << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+bool CompareLiteralToLiteralProto(const xla::Literal& a,
+ const xla::LiteralProto& b) {
+ auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
+ bool equal = a == *l_b;
+ if (!equal) {
+ LOG(INFO) << "Literal and LiteralProto don't match "
+ << a.ToProto().DebugString() << " != " << b.DebugString();
+ }
+ return equal;
+}
+
+xla::XlaComputation AddAndScale() {
+ xla::XlaBuilder builder("AddAndScale");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ auto c = xla::ConstantR0<float>(&builder, 3.0f);
+ xla::Mul(sum, c);
+ return builder.Build().ValueOrDie();
+}
+
+xla::XlaComputation AddAndTuple() {
+ xla::XlaBuilder builder("AddAndTuple");
+ auto p0 = xla::Parameter(&builder, 0,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
+ auto p1 = xla::Parameter(&builder, 1,
+ xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
+ auto sum = xla::Add(p0, p1);
+ xla::Tuple(&builder, {sum});
+ return builder.Build().ValueOrDie();
+}
+
+void StoreComputationSnapshot(const xla::XlaComputation& computation,
+ xla::HloSnapshot* dst) {
+ auto snapshot = computation.Snapshot().ValueOrDie();
+ *dst = *snapshot;
+}
+
+TEST(RawApiTest, ReadAndWriteState) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteral(root, handle);
+ auto release = ops::XRTReleaseAllocationHandle(
+ root.WithControlDependencies(read_back), handle);
+ TF_ASSERT_OK(root.status());
+
+ tensorflow::ClientSession session(root);
+ std::vector<tensorflow::Tensor> outputs;
+ TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
+ {release}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, ReadAndWriteStateAutoFree) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = TwoElementTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto handle = ops::XRTAllocate(root, value);
+ auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
+}
+
+TEST(RawApiTest, SubBuffer) {
+ xrt::XLAAllocation alloc;
+ alloc.set_device_ordinal(0);
+ *alloc.mutable_value() = NestedTuple();
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
+ auto base_handle = ops::XRTAllocate(root, value);
+ auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
+ auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
+ auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
+ auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
+ auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
+ auto sub_00 = ops::XRTSubTupleAndRelease(
+ root.WithControlDependencies(
+ {sub_0.output_handle.op(), sub_1.output_handle.op()}),
+ base_handle, index_00);
+ auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
+ auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
+ auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({value_0, value_1, value_00}, &outputs));
+
+ auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
+ auto base_elements = base_literal->DecomposeTuple();
+ auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
+ xla::LiteralProto response_00;
+ EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
+ EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
+}
+
+TEST(RawApiTest, MakeTuple) {
+ xrt::XLAAllocation alloc_0;
+ alloc_0.set_device_ordinal(0);
+ *alloc_0.mutable_value() = TwoElementTuple();
+ xrt::XLAAllocation alloc_1;
+ alloc_1.set_device_ordinal(0);
+ *alloc_1.mutable_value() = ScalarLiteral();
+
+ // The trivial tuple that just forwards its input and releases it.
+ xrt::XLATupleNode desc_0;
+ desc_0.set_input_index(0);
+ desc_0.set_release_input_handle(true);
+
+ xrt::XLATupleNode desc_1;
+ auto subdesc_10 = desc_1.add_tuples();
+ auto subdesc_11 = desc_1.add_tuples();
+ subdesc_10->set_input_index(0);
+ auto subdesc_110 = subdesc_11->add_tuples();
+ subdesc_110->set_input_index(0);
+ auto subdesc_111 = subdesc_11->add_tuples();
+ subdesc_111->set_input_index(1);
+
+ xrt::XLATupleNode desc_2;
+ auto subdesc_20 = desc_2.add_tuples();
+ auto subdesc_21 = desc_2.add_tuples();
+ subdesc_20->set_input_index(1);
+ subdesc_20->set_release_input_handle(true);
+ subdesc_21->set_input_index(0);
+ subdesc_21->set_release_input_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto value_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
+ auto handle_0 = ops::XRTAllocate(root, value_0);
+ auto value_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
+ auto handle_1 = ops::XRTAllocate(root, value_1);
+ auto tuple_0 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
+ auto handle_2 =
+ ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
+ // handle_0 has now been released.
+ auto tuple_1 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
+ auto handle_3 = ops::XRTMakeTuple(
+ root, tuple_1,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ auto tuple_2 =
+ ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
+ // Make sure this runs after handle_3 has completed, since it will free
+ // handle_1 and handle_2.
+ auto handle_4 = ops::XRTMakeTuple(
+ root.WithControlDependencies(handle_3), tuple_2,
+ {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
+ // handle_1 and handle_2 have now been released.
+
+ auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
+ auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
+ xla::LiteralProto response_0;
+ EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
+ xla::LiteralProto response_1;
+ EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
+
+ auto expected_0 = MakeTuple0();
+ EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
+ auto expected_1 = NestedTuple();
+ EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
+}
+
+TEST(RawApiTest, CompileAndExecute) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+}
+
+TEST(RawApiTest, CompileAndExecuteReturnTuple) {
+ xrt::XLAAllocation p0;
+ p0.set_device_ordinal(0);
+ *p0.mutable_value() = FloatVector({1.0f, 2.0f});
+ xrt::XLAAllocation p1;
+ p1.set_device_ordinal(0);
+ *p1.mutable_value() = FloatVector({8.0f, 5.0f});
+
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
+ *shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::F32, {2})});
+ StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto p0_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
+ auto p0_handle = ops::XRTAllocate(root, p0_value);
+ auto p1_value =
+ ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
+ auto p1_handle = ops::XRTAllocate(root, p1_value);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ {Output(p0_handle), Output(p1_handle)});
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
+ auto expected = xla::LiteralUtil::MakeTuple({sum.get()});
+ EXPECT_TRUE(CompareLiteralToLiteralProto(*expected, response));
+}
+
+} // namespace
+
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
+ "Tensorflow device type to use for test, e.g., XLA_CPU"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto
new file mode 100644
index 0000000000..5678f0905f
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt.proto
@@ -0,0 +1,78 @@
+syntax = "proto3";
+
+package xrt;
+
+import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/compiler/xla/service/hlo.proto";
+
+// Options for an XLA compilation.
+message XLAComputationConfig {
+ // The number of replicas the computation will be run on. If this is
+ // default (0) it is interpreted as 1.
+ int32 num_replicas = 1;
+ // The number of "model-parallel" cores per replica. If this is
+ // default (0) it is interpreted as 1.
+ int32 num_cores_per_replica = 2;
+ // Optional metadata about host sends and recvs.
+ tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
+
+ // The arg/result shapes for the whole computation.
+ xla.ProgramShape program_shape = 4;
+ // The arg/result shapes for each core of a model-parallel
+ // computation. per_core_program_shape is optional for a
+ // single-core computation.
+ repeated xla.ProgramShape per_core_program_shape = 5;
+}
+
+// Options and XLA computation for a compilation.
+message XLAComputation {
+ XLAComputationConfig config = 1;
+ xla.HloSnapshot hlo_snapshot = 2;
+}
+
+// Literal to allocate space for, and transfer to, device memory.
+message XLAAllocation {
+ int32 device_ordinal = 1;
+ xla.LiteralProto value = 2;
+}
+
+// Node in a tree describing a tuple constructed from input handles. A
+// node is an internal node if tuples is non-empty, in which case
+// input_index and release_input_handle are ignored. Otherwise a node
+// is a leaf node. Each leaf XLATupleNode holds the index of an input
+// which corresponds to a handle that will be grafted onto the output
+// tuple at that location. If release_input_handle is true, that input
+// handle will be released and become invalid. Inputs may be repeated,
+// in which case leaves of the output tuple will alias. If an input is
+// repeated, release_input_handle must be false for every leaf where
+// that input appears.
+//
+// For example, if input 0 has shape {} and input 1 has shape {2,3}
+// then the XLATupleNode with structure {1,{0,1}} corresponds to a
+// tuple with shape {{2,3},{{},{2,3}}}.
+message XLATupleNode {
+ int32 input_index = 1;
+ bool release_input_handle = 2;
+ repeated XLATupleNode tuples = 3;
+}
+
+// Options for an XLA execution.
+message XRTExecutionConfig {
+ // Local device to run on. This is present because the execute Op
+ // may be placed on a device such as CPU or TPU_SYSTEM that
+ // logically manages multiple cores.
+ int32 device_ordinal = 1;
+ // Which model-parallel computation to run from the compiled bundle.
+ int32 core_index_in_replica = 2;
+ // Optional key to disambiguate between executions. This is only
+ // needed if multiple host send/recvs may be outstanding
+ // concurrently with executions.
+ string execution_instance_key = 3;
+ // If non-zero, rng_seed to reset the core with.
+ uint32 rng_seed = 4;
+ // If true, release allocation handles on the inputs after running.
+ bool release_input_handles = 5;
+ // If true, release the handle to the computation after running.
+ bool release_compilation_handle = 6;
+}
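As a concrete illustration of the XLATupleNode comment above (a sketch, not part of the change itself): the {1,{0,1}} descriptor for two input handles could be assembled with the generated proto API roughly as follows. The helper name BuildExampleTupleDescriptor is hypothetical.

    xrt::XLATupleNode BuildExampleTupleDescriptor() {
      // Root is an internal node with two children, giving the structure
      // {1, {0, 1}} described in the XLATupleNode comment.
      xrt::XLATupleNode root;
      xrt::XLATupleNode* first = root.add_tuples();   // leaf: input 1
      first->set_input_index(1);
      xrt::XLATupleNode* second = root.add_tuples();  // internal node {0, 1}
      xrt::XLATupleNode* second_0 = second->add_tuples();
      second_0->set_input_index(0);                   // leaf: input 0
      xrt::XLATupleNode* second_1 = second->add_tuples();
      second_1->set_input_index(1);                   // leaf: input 1 again
      // Input 1 appears twice, so release_input_handle must stay false
      // for both of its leaves.
      return root;
    }

With input 0 of shape {} and input 1 of shape {2,3}, this descriptor yields an output tuple of shape {{2,3},{{},{2,3}}}, matching the example in the proto comment.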
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.cc b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
new file mode 100644
index 0000000000..4844c7fb71
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.cc
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+const char* kXRTCompilationCacheResourceName = "xrt_compilation_cache";
+
+XRTCompilationCache::EntryRefImpl::EntryRefImpl(XRTCompilationCache* parent,
+ CompiledSubgraph* entry)
+ : parent_(parent), entry_(entry) {
+ entry_->Ref();
+}
+
+XRTCompilationCache::EntryRefImpl::~EntryRefImpl() {
+ parent_->DiscardEntryRef(entry_);
+}
+
+XRTCompilationCacheEntry XRTCompilationCache::EntryRefImpl::get() {
+ return XRTCompilationCacheEntry(entry_->program.get());
+}
+
+XRTCompilationCache::XRTCompilationCache(int max_number_of_entries)
+ : max_cache_entries_(max_number_of_entries) {
+ CHECK_GE(max_cache_entries_, 0);
+ VLOG(1) << "Created compilation cache max " << max_cache_entries_
+ << " entries.";
+}
+
+XRTCompilationCache::~XRTCompilationCache() {
+ VLOG(1) << "XRTCompilationCache::~XRTCompilationCache()";
+ while (!entries_by_last_use_.empty()) {
+ MarkOldestEntryForEviction();
+ }
+ // By the time the cache is deleted all reference holders should have already
+ // been deleted, since they were holding references to the cache. So all
+ // entries should be gone at this point.
+ CHECK_EQ(cache_.size(), 0);
+ CHECK_EQ(entries_by_uid_.size(), 0);
+ CHECK_EQ(cache_entries_, 0);
+ CHECK_EQ(marked_for_eviction_entries_, 0);
+}
+
+Status XRTCompilationCache::Release(int64 uid) {
+ absl::MutexLock lock(&mu_);
+ auto iter = entries_by_uid_.find(uid);
+
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No cache entry found for uid ", uid);
+ }
+
+ DiscardEntryRefLocked(iter->second);
+
+ VLOG(1) << "After releasing entry " << uid << " refs cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_
+ << "), marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return Status::OK();
+}
+
+void XRTCompilationCache::DiscardEntryRef(CompiledSubgraph* entry) {
+ absl::MutexLock lock(&mu_);
+ DiscardEntryRefLocked(entry);
+}
+
+void XRTCompilationCache::DiscardEntryRefLocked(CompiledSubgraph* entry) {
+ if (entry->RefCountIsOne()) {
+ // The last reference to this entry is going away, so really delete it from
+ // the cache in such a way that it can't be restored by being looked up
+ // again.
+
+ // Sanity-check that it has been marked for eviction.
+ CHECK(entries_by_last_use_.find(entry->last_use) ==
+ entries_by_last_use_.end());
+ // Update the counter tracking the number of entries that are marked for
+ // eviction.
+ --marked_for_eviction_entries_;
+
+ // Remove the entry from the cache.
+ auto erased = cache_.erase(entry->key);
+ if (erased == 0) {
+ LOG(FATAL) << "Tried to discard nonexistent cache entry";
+ }
+ erased = entries_by_uid_.erase(entry->uid);
+ CHECK_EQ(erased, 1);
+ }
+ entry->Unref();
+}
+
+void XRTCompilationCache::MarkOldestEntryForEviction() {
+ CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
+ VLOG(1) << "Marking " << entry_to_mark->key << " for eviction";
+ entries_by_last_use_.erase(entry_to_mark->last_use);
+ --cache_entries_;
+ ++marked_for_eviction_entries_;
+ // Discard the cache's reference to entry. If steps are holding onto
+ // references to entry it won't be deleted until the last step holding it
+ // completes. It stays in the cache in the meantime and can be resurrected
+ // by a call to CompileIfKeyAbsent if that occurs before the last reference
+ // expires.
+ DiscardEntryRefLocked(entry_to_mark);
+}
+
+void XRTCompilationCache::LookupEntryMarkedForEviction(
+ CompiledSubgraph* entry) {
+ // The entry was previously marked for eviction (or is newly created) so
+ // unmark it. Add a reference (owned by the cache), update the cache size, and
+ // mark something old for eviction if necessary.
+ entry->Ref();
+ --marked_for_eviction_entries_;
+ ++cache_entries_;
+
+ // Mark the least-recently-used non-marked entry for eviction. Never mark the
+ // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size() == 1,
+ // which means there is only one entry not already marked for eviction), so that
+ // the most recently used entry persists in the cache even when
+ // max_cache_entries_ would otherwise be exceeded.
+ while (entries_by_last_use_.size() > 1 &&
+ cache_entries_ > max_cache_entries_) {
+ MarkOldestEntryForEviction();
+ }
+}
+
+XRTCompilationCache::CompiledSubgraph* XRTCompilationCache::InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) {
+ CompiledSubgraph* entry = new CompiledSubgraph();
+ entry->parent = this;
+ entry->key = key;
+ entry->uid = next_uid_++;
+ // Add the entry to the cache. Once the computation has been compiled,
+ // LookupEntryMarkedForEviction will be called to potentially mark old entries
+ // for eviction if the cache no longer has room for them.
+ //
+ // At this point there is one reference to entry, which is owned by the caller
+ // who created the entry. A second reference, owned by the cache, will be
+ // added below since we leave the entry in the 'marked for eviction' state
+ // here.
+ auto cache_inserted =
+ cache_.insert(std::pair<string, CompiledSubgraph*>(key, entry));
+ CHECK(cache_inserted.second);
+
+ // Initialize the program outside the lock so that other cache operations
+ // can proceed during the (potentially lengthy) initialization.
+ Status s;
+ std::unique_ptr<xla::LocalExecutable> program;
+ {
+ mu_.Unlock();
+ { s = initialize_program(&program); }
+ mu_.Lock();
+ }
+
+ // Add the entry to the uid index.
+ auto uid_inserted = entries_by_uid_.insert(
+ std::pair<int64, CompiledSubgraph*>(entry->uid, entry));
+ CHECK(uid_inserted.second);
+
+ entry->initialized = true;
+ entry->initialization_status = s;
+ if (s.ok()) {
+ entry->program = std::move(program);
+ }
+ // Add the entry to marked_for_eviction_entries_ since it will be adjusted
+ // down again when the newly-created entry gets unmarked.
+ ++marked_for_eviction_entries_;
+ return entry;
+}
+
+Status XRTCompilationCache::CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function) {
+ CompiledSubgraph* entry = nullptr;
+
+ absl::MutexLock lock(&mu_);
+ auto iter = cache_.find(key);
+
+ if (iter == cache_.end()) {
+ // The single ref on the newly-created entry is owned by the caller.
+ VLOG(1) << "Before adding new entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = InitializeEntry(key, compile_function);
+ } else {
+ VLOG(1) << "Before refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+ entry = iter->second;
+ // Make a new reference that is owned by the caller.
+ entry->Ref();
+ // Block if necessary until the subgraph has been initialized.
+ mu_.Await(absl::Condition(
+ +[](CompiledSubgraph* e) { return e->initialized; }, entry));
+ }
+
+ // Let the caller know the uid of the entry.
+ *uid = entry->uid;
+
+ // Remove the old LRU-table entry if it wasn't already marked for eviction.
+ auto erased = entries_by_last_use_.erase(entry->last_use);
+ // Update the LRU table indicating this entry is the most recently used.
+ entry->last_use = use_counter_++;
+ entries_by_last_use_[entry->last_use] = entry;
+ if (erased == 0) {
+ // The entry had been marked for eviction, or is newly created.
+ LookupEntryMarkedForEviction(entry);
+ }
+
+ VLOG(1) << "After refreshing entry for key " << key << " cache is "
+ << cache_.size() << " entries ("
+ << cache_entries_ + marked_for_eviction_entries_ << "), "
+ << " marked for eviction "
+ << (cache_.size() - entries_by_last_use_.size()) << " entries ("
+ << marked_for_eviction_entries_ << ").";
+
+ return entry->initialization_status;
+}
+
+Status XRTCompilationCache::Lookup(
+ int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
+ entry->reset();
+
+ absl::MutexLock lock(&mu_);
+ const auto iter = entries_by_uid_.find(uid);
+ if (iter == entries_by_uid_.end()) {
+ return errors::NotFound("No executable found for uid ", uid);
+ }
+ CompiledSubgraph* cache_entry = iter->second;
+ *entry = std::unique_ptr<XRTCompilationCacheEntryRef>(
+ new EntryRefImpl(this, cache_entry));
+ return Status::OK();
+}
+
+string XRTCompilationCache::DebugString() { return "XRTCompilationCache"; }
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_compilation_cache.h b/tensorflow/compiler/xrt/xrt_compilation_cache.h
new file mode 100644
index 0000000000..c505299a45
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_compilation_cache.h
@@ -0,0 +1,238 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace tensorflow {
+
+extern const char* kXRTCompilationCacheResourceName;
+
+struct XRTCompilationCacheEntry {
+ explicit XRTCompilationCacheEntry(xla::LocalExecutable* executable)
+ : executable(executable) {}
+
+ // Returns a non-owned pointer to an immutable executable.
+ xla::LocalExecutable* get_executable() const { return executable; }
+
+ private:
+ xla::LocalExecutable* executable;
+};
+
+// Base class for a reference to a cached executable. A unique_ptr to a
+// XRTCompilationCacheEntryRef is returned by the cache Lookup methods below,
+// and ensures the underlying executable is not garbage-collected until the
+// client discards the ptr.
+class XRTCompilationCacheEntryRef {
+ public:
+ virtual ~XRTCompilationCacheEntryRef() = default;
+
+ // Returns a XRTCompilationCacheEntry that should not be used beyond the
+ // lifetime of the XRTCompilationCacheEntryRef.
+ virtual XRTCompilationCacheEntry get() = 0;
+};
+
+// Cache for compiled XLA executables.
+// TODO(b/112646171) rationalize this with the other compilation caches.
+//
+// Each key identifies a unique XLA computation, and the value is executable
+// generated by compiling the computation.
+//
+// When a computation is considered for compilation, the client calls
+//
+// auto key = <compute key for computation>;
+// auto compile_function = <lambda to compile computation into executable>;
+// int64 uid;
+// CompileIfKeyAbsent(computation_key, &uid, compile_function);
+//
+// where computation_key is the key computed for the computation. On success,
+// uid contains an identifier that can be used to look up the executable. If the
+// compiled executable was not already present in the cache, compile_function is
+// called to generate it.
+//
+// The caller is responsible for calling Release(uid) once for every
+// call to CompileIfKeyAbsent(key, ...) to discard the reference to the
+// compilation results, after the caller is sure it will not look up the
+// compiled executables again.
+//
+// Subsequently the client can call
+//
+// std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+// Lookup(uid, &entry);
+// auto proto = entry->get();
+//
+// to access a cached executable.
+class XRTCompilationCache : public ResourceBase {
+ public:
+ // There is no way in general to discover the size taken by an XLA executable,
+ // so the cache defaults to a specific number of entries to determine when to
+ // start evicting programs. TODO(b/112592410) change this if the XLA API gets
+ // a mechanism to query size.
+ explicit XRTCompilationCache(int max_number_of_entries);
+ ~XRTCompilationCache() override;
+
+ // Ensures there is an entry for key present in the cache. By the time
+ // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
+ // for key, and that entry will remain valid at least until Release is called
+ // on the returned uid. The first call to CompileIfKeyAbsent with a key that
+ // is not in the cache will evaluate compile_function to compute the value to
+ // use in the entry. Subsequent calls with the same key will block until
+ // compile_function completes. Other cache reads and inserts may proceed on
+ // other threads while compile_function is executing. The caller is
+ // responsible for calling Release(uid) to manually discard its reference to
+ // the compiled program, once the caller no longer needs to look up the
+ // compiled program again.
+ //
+ // compile_function should compile the computation represented by key and fill
+ // the xla::LocalExecutable into its passed argument. It should return OK
+ // if and only if compilation succeeds. The executable will be discarded on
+ // non-OK status.
+ Status CompileIfKeyAbsent(
+ const string& key, int64* uid,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ compile_function);
+
+ Status Release(int64 uid);
+
+ // Looks up an executable corresponding to uid. On success a pointer to an
+ // EntryRef holding the program is returned in entry.
+ Status Lookup(int64 uid, std::unique_ptr<XRTCompilationCacheEntryRef>* entry);
+
+ string DebugString() override;
+
+ private:
+ // An entry in the compilation cache. The entry is deleted once it has been
+ // marked for eviction from the cache _and_ all looked-up entries have been
+ // released. When the entry is first created, it is uninitialized and a
+ // client-supplied compilation function is run outside the cache's lock to
+ // generate the program to be stored in the entry. Any other client that
+ // requests the entry will block until it has been initialized. Each entry has
+ // a last_use value that is set from a monotonically-increasing counter in the
+ // cache whenever the entry is referenced. When the cache becomes full,
+ // entries are marked for eviction in LRU order.
+ struct CompiledSubgraph : public core::RefCounted {
+ ~CompiledSubgraph() override = default;
+
+ XRTCompilationCache* parent = nullptr; // Not owned.
+ bool initialized = false;
+ // The Status returned by the compilation function when the entry is
+ // initialized. This status will be returned to any client that requests the
+ // entry.
+ Status initialization_status;
+ // Counter to keep track of LRU entries for the eviction policy.
+ int64 last_use = -1;
+ // The unique key describing this entry.
+ string key;
+ // The uid describing this entry.
+ int64 uid;
+ // The compiled payload corresponding to the key.
+ std::unique_ptr<xla::LocalExecutable> program;
+ };
+
+ // Wrapper for a cache entry that holds a reference to the entry until the
+ // wrapper is deleted. This wrapper is the concrete type of
+ // XRTCompilationCacheEntryRef returned by Lookup.
+ class EntryRefImpl : public XRTCompilationCacheEntryRef {
+ public:
+ EntryRefImpl(XRTCompilationCache* parent, CompiledSubgraph* entry);
+ ~EntryRefImpl() override;
+
+ XRTCompilationCacheEntry get() override;
+
+ private:
+ XRTCompilationCache* parent_; // Not owned.
+ // A reference to entry_ is acquired in the constructor and released via
+ // parent->DiscardEntryRef in the destructor.
+ CompiledSubgraph* entry_;
+ };
+
+ // Releases one reference to entry. This is called by the cache when entry is
+ // marked for eviction, or by an EntryRefImpl when it is destroyed. Before the
+ // last reference to entry is released, entry is removed from cache_.
+ void DiscardEntryRef(CompiledSubgraph* entry);
+ void DiscardEntryRefLocked(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Marks the oldest unmarked entry for eviction. Requires that there is at
+ // least one such entry.
+ void MarkOldestEntryForEviction() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Updates datastructures to indicate that entry, which had been marked for
+ // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
+ // entry is newly created, or an entry that has been marked for eviction but
+ // not yet evicted is looked up.
+ //
+ // First the entry is unmarked for eviction, i.e. the cache gains a reference
+ // to entry, entry's last_use field is set to be the most recent value of
+ // use_counter_ and entries_by_last_use_ is updated accordingly.
+ //
+ // Next, the size of the cache is examined to see if any other entries need to
+ // be marked for eviction now that entry has been unmarked. While the total
+ // number of unmarked cached entries is greater than max_cache_entries_,
+ // entries are marked for eviction in LRU order. The most recently used entry
+ // is never marked for eviction, so it remains in the cache even when the
+ // number of unmarked entries exceeds max_cache_entries_, until it is
+ // replaced by something more recently used.
+ void LookupEntryMarkedForEviction(CompiledSubgraph* entry)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Creates a new entry by running initialize_program and places it in the
+ // cache to be looked up by key. The new entry is in the 'marked for eviction'
+ // state (not present in entries_by_last_use_) and the caller is expected to
+ // call LookupEntryMarkedForEviction after InitializeEntry.
+ //
+ // **InitializeEntry releases mu_ during the call to initialize_program.**
+ CompiledSubgraph* InitializeEntry(
+ const string& key,
+ const std::function<Status(std::unique_ptr<xla::LocalExecutable>*)>&
+ initialize_program) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // The maximum number of entries that are stored in the cache before entries
+ // are marked for eviction.
+ const int max_cache_entries_;
+
+ mutable absl::Mutex mu_;
+ // The uid to assign to the next new entry created.
+ int64 next_uid_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are stored and not marked for eviction.
+ int cache_entries_ GUARDED_BY(mu_) = 0;
+ // The total number of entries that are marked for eviction.
+ int marked_for_eviction_entries_ GUARDED_BY(mu_) = 0;
+ // The value to assign to the last_use field of the next entry that is looked
+ // up.
+ int64 use_counter_ GUARDED_BY(mu_) = 0;
+ // All the executables that can be looked up in the cache, indexed by key. An
+ // entry is marked for eviction iff it is present in cache_ and not in
+ // entries_by_last_use_.
+ std::unordered_map<string, CompiledSubgraph*> cache_ GUARDED_BY(mu_);
+ // All the executable entries that can be looked up in the cache indexed by
+ // uid.
+ std::unordered_map<int64, CompiledSubgraph*> entries_by_uid_ GUARDED_BY(mu_);
+ // Map from last_use to entry, used to mark entries for eviction in LRU
+ // order. If an entry's last_use counter is not present as a key in
+ // entries_by_last_use_ then the entry has been marked for eviction.
+ std::map<int64, CompiledSubgraph*> entries_by_last_use_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_COMPILATION_CACHE_H_
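To make the calling convention documented on XRTCompilationCache concrete, here is a minimal caller-side sketch (illustrative only, not from the XRT sources). BuildExecutable is a hypothetical helper supplied by the caller that runs the XLA compiler and fills in the LocalExecutable.

    #include <memory>

    #include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical: compiles the computation identified by `key`.
    tensorflow::Status BuildExecutable(
        const tensorflow::string& key,
        std::unique_ptr<xla::LocalExecutable>* executable);

    tensorflow::Status RunCached(tensorflow::XRTCompilationCache* cache,
                                 const tensorflow::string& key) {
      tensorflow::int64 uid;
      // Compiles only if `key` is absent; otherwise blocks until the existing
      // entry is initialized. Either way the caller now holds a reference.
      TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
          key, &uid,
          [&](std::unique_ptr<xla::LocalExecutable>* executable) {
            return BuildExecutable(key, executable);
          }));
      std::unique_ptr<tensorflow::XRTCompilationCacheEntryRef> entry;
      TF_RETURN_IF_ERROR(cache->Lookup(uid, &entry));
      xla::LocalExecutable* executable = entry->get().get_executable();
      (void)executable;  // ... launch the executable here ...
      entry.reset();               // drop the Lookup reference first
      return cache->Release(uid);  // then drop the CompileIfKeyAbsent reference
    }

The explicit Release at the end mirrors the contract in the class comment: one Release(uid) for every successful CompileIfKeyAbsent, issued only once the caller no longer needs to look the program up.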
diff --git a/tensorflow/compiler/xrt/xrt_device.cc b/tensorflow/compiler/xrt/xrt_device.cc
new file mode 100644
index 0000000000..ea40e6c895
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for managing access to XLA resources.
+
+#include "tensorflow/compiler/xrt/xrt_device.h"
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+/*static*/ Status XRTGenericDeviceAccessor::GetResourceManager(
+ OpKernelContext* ctx, ResourceMgr** rm) {
+ *rm = ctx->resource_manager();
+ return Status::OK();
+}
+
+/*static*/ Status XRTGenericDeviceAccessor::InitScopedRef(
+ OpKernelContext* ctx, int device_ordinal, ScopedRef* scoped_ref) {
+ const XlaDevice::Metadata* metadata;
+ TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata));
+ if (device_ordinal != metadata->device_ordinal()) {
+ return errors::Internal("XRT device ordinal requested ", device_ordinal,
+ " on device with ordinal ",
+ metadata->device_ordinal());
+ }
+ scoped_ref->Acquire(metadata->client());
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_device.h b/tensorflow/compiler/xrt/xrt_device.h
new file mode 100644
index 0000000000..1e3fddd2a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_device.h
@@ -0,0 +1,66 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+// This accessor is used for XLA CPU/GPU. It uses the device resource manager,
+// so e.g., on multi-GPU setups the compilation cache will not be shared across
+// devices.
+class XRTGenericDeviceAccessor {
+ public:
+ static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm);
+
+ // We use a ScopedRef pattern here even though it's not strictly necessary,
+ // just so that templated uses of this and the TPU accessor class will be as
+ // similar as possible.
+ class ScopedRef {
+ public:
+ ScopedRef() {}
+ ~ScopedRef() {}
+
+ ScopedRef(const ScopedRef&) = delete;
+ ScopedRef& operator=(const ScopedRef&) = delete;
+
+ // Returns the XLA device protected by this ScopedRef.
+ xla::LocalClient* client() { return client_; }
+ xla::Backend* backend() { return client_->mutable_backend(); }
+ int device_ordinal() { return 0; }
+
+ private:
+ // XRTGenericDeviceAccessor::InitScopedRef is the only way to initialize
+ // ScopedRef.
+ friend class XRTGenericDeviceAccessor;
+
+ void Acquire(xla::LocalClient* client) { client_ = client; }
+
+ xla::LocalClient* client_ = nullptr;
+ };
+
+ static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal,
+ ScopedRef* scoped_ref);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_DEVICE_H_
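For reference, a kernel-side sketch of the accessor pattern above (illustrative only, not from the XRT sources); the free function ComputeSketch and the device ordinal of 0 are assumptions made for the example:

    #include "tensorflow/compiler/xrt/xrt_device.h"
    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/framework/resource_mgr.h"

    // Would normally live inside some OpKernel::Compute(OpKernelContext* ctx).
    void ComputeSketch(tensorflow::OpKernelContext* ctx) {
      tensorflow::ResourceMgr* rm;
      OP_REQUIRES_OK(
          ctx, tensorflow::XRTGenericDeviceAccessor::GetResourceManager(ctx, &rm));
      (void)rm;  // e.g., look up or create an XRTCompilationCache here.
      // Ordinal 0 is assumed for this sketch.
      tensorflow::XRTGenericDeviceAccessor::ScopedRef device_ref;
      OP_REQUIRES_OK(ctx, tensorflow::XRTGenericDeviceAccessor::InitScopedRef(
                              ctx, /*device_ordinal=*/0, &device_ref));
      xla::Backend* backend = device_ref.backend();
      (void)backend;  // ... allocate XRT state or run executables against it ...
    }

The same templated kernel code can then be shared with a TPU-specific accessor, which is the stated motivation for the ScopedRef indirection.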
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
new file mode 100644
index 0000000000..911ac9a78b
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for allocating XLA literals in device memory and managing handles
+// that refer to them.
+
+#include "tensorflow/compiler/xrt/xrt_state.h"
+
+#include <stdint.h>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+namespace {
+
+const char* kTupleContainer = "tuples";
+
+// Counter used to assign unique handles.
+mutex _uid_mutex(tensorflow::LINKER_INITIALIZED);
+int64 _uid GUARDED_BY(_uid_mutex) = 0;
+int64 get_uid() {
+ mutex_lock l(_uid_mutex);
+ return _uid++;
+}
+
+Status AllocateScopedShapedBuffer(
+ xla::Backend* backend, int device_ordinal, const xla::Shape& shape,
+ std::unique_ptr<xla::ScopedShapedBuffer>* buffer) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ // XLA may use a different representation on device than the representation on
+ // the host. XLA does not document any contract for the relationship between
+ // these representations. Right now, the device shape is always a superset
+ // of the host shape, meaning that for any valid ShapeIndex in the host shape
+ // that ShapeIndex is also valid in the device shape, but not vice versa. In
+ // particular, some host-side types are rewritten to be tuples. We rely on
+ // this property when making sub-buffers, because we assume that if the client
+ // requests the host-shape sub-buffer at index i, that will correspond to the
+ // right device-shape sub-buffer at the same index.
+ xla::Shape on_device_shape = transfer_manager->HostShapeToDeviceShape(shape);
+
+ // The ScopedShapedBuffer frees the buffers that have so far been allocated if
+ // it goes out of scope. That's useful if we return early as the result of an
+ // error allocating one of the later buffers.
+ *buffer = absl::make_unique<xla::ScopedShapedBuffer>(
+ shape, on_device_shape, allocator, device_ordinal);
+ for (auto& index_to_buffer : (*buffer)->buffers()) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(
+ xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size, /*retry_on_failure=*/false));
+ // Move our buffer into the ScopedShapedBuffer, which takes ownership of it.
+ index_to_buffer.second = buffer.Forget();
+ VLOG(2) << "Allocated buffer at " << index_to_buffer.second.opaque()
+ << " index " << index_to_buffer.first.ToString();
+ }
+
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), *(buffer->get())));
+
+ return Status::OK();
+}
+
+} // namespace
+
+XRTBufferAllocation::XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator)
+ : allocation_(allocation),
+ device_ordinal_(device_ordinal),
+ allocator_(allocator) {}
+
+XRTBufferAllocation::~XRTBufferAllocation() {
+ // Deallocate() explicitly allows allocation_ to be null.
+ Status s = allocator_->Deallocate(device_ordinal_, allocation_);
+ // Nothing to do but check fail here if memory datastructures are corrupted.
+ CHECK(s.ok());
+ VLOG(2) << "Freed buffer at " << allocation_.opaque();
+}
+
+const se::DeviceMemoryBase& XRTBufferAllocation::allocation() {
+ return allocation_;
+}
+
+void XRTBufferAllocation::DiscardAllocation() {
+ // Replace the allocation with a null.
+ allocation_ = se::DeviceMemoryBase();
+}
+
+XRTTupleAllocation::XRTTupleAllocation(int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape)
+ : device_ordinal_(device_ordinal),
+ allocator_(allocator),
+ on_host_shape_(on_host_shape),
+ on_device_shape_(on_device_shape),
+ buffers_(&on_device_shape_) {}
+
+XRTTupleAllocation::~XRTTupleAllocation() {
+ for (auto& buffer : buffers_) {
+ buffer.second->Unref();
+ }
+}
+
+/*static*/ Status XRTTupleAllocation::CreateAndTransfer(
+ const xla::Literal& literal, xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+
+ std::unique_ptr<xla::ScopedShapedBuffer> scoped_buffer;
+ TF_RETURN_IF_ERROR(AllocateScopedShapedBuffer(
+ backend, device_ordinal, literal.shape(), &scoped_buffer));
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
+ stream.get(), literal, *scoped_buffer));
+
+ // By releasing the ScopedShapedBuffer we ensure that the underlying storage
+ // won't be freed when the buffer goes out of scope at the end of this
+ // call. To avoid a leak, there must be no error-case returns from here until
+ // the end of the method.
+ auto shaped_buffer = scoped_buffer->release();
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::CreateFromBuffer(
+ const xla::ShapedBuffer& shaped_buffer, xla::Backend* backend,
+ int device_ordinal, XRTTupleAllocation** allocation) {
+ auto allocator = backend->memory_allocator();
+
+ *allocation = new XRTTupleAllocation(device_ordinal, allocator,
+ shaped_buffer.on_host_shape(),
+ shaped_buffer.on_device_shape());
+ (*allocation)
+ ->InitializeFromShapedBuffer(shaped_buffer, allocator, device_ordinal);
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::ToLiteral(xla::Backend* backend, int device_ordinal,
+ std::unique_ptr<xla::Literal>* literal) {
+ auto transfer_manager = backend->transfer_manager();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+ TF_ASSIGN_OR_RETURN(*literal, transfer_manager->TransferLiteralFromDevice(
+ stream.get(), ToShapedBuffer()));
+ return Status::OK();
+}
+
+void XRTTupleAllocation::DiscardAllocation(
+ const xla::ShapeIndex& buffer_index) {
+ buffers_.element(buffer_index)->DiscardAllocation();
+}
+
+const xla::Shape& XRTTupleAllocation::on_host_shape() { return on_host_shape_; }
+
+const xla::Shape& XRTTupleAllocation::on_device_shape() {
+ return on_device_shape_;
+}
+
+int XRTTupleAllocation::device_ordinal() { return device_ordinal_; }
+
+const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() {
+ return buffers_.element({})->allocation();
+}
+
+/*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation) {
+ string key_string = strings::StrCat(key);
+ TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm,
+ int64 key) {
+ string key_string = strings::StrCat(key);
+ return rm->Delete<XRTTupleAllocation>(kTupleContainer, key_string);
+}
+
+// Helper typedef to make ShapeTree ForEach helper lambda signatures more
+// readable. They take a const T& parameter, where in this case T is the
+// pointer type defined below.
+typedef XRTBufferAllocation* XRTBufferAllocationPtr;
+
+/*static*/ Status XRTTupleAllocation::MakeSubBuffer(
+ XRTTupleAllocation* parent, const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation, bool alias_parent_allocation) {
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* host_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_host_shape(), subshape));
+ TF_ASSIGN_OR_RETURN(
+ const xla::Shape* device_sub_shape,
+ xla::ShapeUtil::TryGetSubshape(parent->on_device_shape(), subshape));
+
+ *allocation =
+ new XRTTupleAllocation(parent->device_ordinal(), parent->allocator_,
+ *host_sub_shape, *device_sub_shape);
+ if (alias_parent_allocation) {
+ // Copy the subtree of allocations from the parent allocation.
+ (*allocation)->buffers_.CopySubtreeFrom(parent->buffers_, subshape, {});
+ // Increment the refcount on each aliased buffer.
+ (*allocation)
+ ->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ } else {
+ // Find the buffers in the parent allocation that match the subtree, and
+ // move the parent allocation's buffer over to the new allocation.
+ (*allocation)
+ ->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ // Extend the allocation's index to the parent's frame by adding
+ // subshape as a prefix.
+ xla::ShapeIndex parent_index = subshape;
+ for (int i = 0; i < index.size(); ++i) {
+ parent_index.push_back(index[i]);
+ }
+ *buffer = parent->buffers_.element(parent_index);
+ *parent->buffers_.mutable_element(parent_index) =
+ new XRTBufferAllocation(se::DeviceMemoryBase(),
+ parent->device_ordinal(),
+ parent->allocator_);
+ });
+ }
+
+ return Status::OK();
+}
+
+/* static */ Status XRTTupleAllocation::ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape) {
+ // Initialize both host and device shape to be the 'spine' of the new tuple
+ // shape, given by the shape of the tree of tuples.
+ *host_shape = elements.shape();
+ *device_shape = elements.shape();
+ // Now go over the leaves of the tree of tuples, and 'graft' the host/device
+ // shapes of the allocation at that leaf onto the expanded host/device shapes
+ // at the leaf position.
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ if (element.allocation == nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a null internal node at index ",
+ index.ToString());
+ }
+ if (device_ordinal != element.allocation->device_ordinal() ||
+ allocator != element.allocation->allocator_) {
+ return errors::InvalidArgument(
+ "MakeTuple elements must all be allocated on the same device "
+ "as the destination.");
+ }
+ *xla::ShapeUtil::GetMutableSubshape(host_shape, index) =
+ element.allocation->on_host_shape();
+ *xla::ShapeUtil::GetMutableSubshape(device_shape, index) =
+ element.allocation->on_device_shape();
+ } else {
+ if (element.allocation != nullptr) {
+ return errors::InvalidArgument(
+ "MakeTuple elements has a non-null internal node at index ",
+ index.ToString());
+ }
+ }
+ return Status::OK();
+ }));
+ return Status::OK();
+}
+
+/*static*/ Status XRTTupleAllocation::MakeTuple(
+ xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation) {
+ auto transfer_manager = backend->transfer_manager();
+ auto allocator = backend->memory_allocator();
+ TF_ASSIGN_OR_RETURN(auto stream, backend->BorrowStream(device_ordinal));
+
+ xla::Shape host_shape;
+ xla::Shape device_shape;
+ TF_RETURN_IF_ERROR(ExpandTreeOfTuples(elements, device_ordinal, allocator,
+ &host_shape, &device_shape));
+
+  // Whether each input is aliased or moved into the new tuple is determined
+  // below, per leaf, by its release_allocation_after_use flag. allocation_tmp
+  // is a local pointer that is copied to *allocation at the end only if the
+  // method succeeds.
+ auto allocation_tmp = new XRTTupleAllocation(device_ordinal, allocator,
+ host_shape, device_shape);
+ core::ScopedUnref allocation_unref(allocation_tmp);
+ // First allocate device memory for the new tuple index tables, one at each
+ // internal node of the elements tree. Do this in a separate pass into a
+ // ScopedShapedBuffer so that it's easy to free the newly-allocated memory if
+ // an allocation fails. Make sure the shape has layout so that the code that
+ // writes index tables will be happy lower down.
+ xla::Shape spine_shape = elements.shape();
+ xla::LayoutUtil::SetToDefaultLayout(&spine_shape);
+ auto new_tuple_buffers = absl::make_unique<xla::ScopedShapedBuffer>(
+ spine_shape, spine_shape, allocator, device_ordinal);
+ TF_RETURN_IF_ERROR(elements.ForEachElementWithStatus(
+ [&](const xla::ShapeIndex& index, const ExpandedTupleInput& element) {
+ if (!elements.IsLeaf(index)) {
+ xla::Shape subshape =
+ xla::ShapeUtil::GetSubshape(device_shape, index);
+ uint64 size = transfer_manager->GetByteSizeRequirement(subshape);
+ TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer,
+ allocator->Allocate(device_ordinal, size,
+ /*retry_on_failure=*/false));
+ VLOG(2) << "Allocated buffer at " << buffer.opaque() << " index "
+ << index.ToString();
+ // Move the new buffer into new_tuple_buffers, which takes ownership
+ // of it.
+ new_tuple_buffers->set_buffer(std::move(buffer), index);
+ }
+ return Status::OK();
+ }));
+ // Transfer from the ScopedShapedBuffer to a ShapedBuffer, which does not own
+ // the newly-allocated index tables. Right now there's no owner for the new
+ // index tables, so next we will transfer ownership to the new allocation,
+ // taking care not to return early on any errors in the meantime.
+ xla::ShapedBuffer tuple_buffers = new_tuple_buffers->release();
+  // Now fill in the remaining data structures. After this ForEachElement
+  // completes:
+  //  1) Every leaf element of tuple_buffers will be the root buffer of
+  //     an existing allocation, and every internal element of tuple_buffers
+  //     will be a newly-allocated index table. tuple_buffers does not own any
+  //     of these.
+  //  2) Every element of allocation_tmp->buffers_ will be a correctly
+  //     constructed XRTBufferAllocation wrapping the necessary allocations.
+  //     For buffers in existing allocations there will be a new reference
+  //     owned by the new allocation, and for newly-allocated index tables
+  //     there will be a single reference owned by the new allocation.
+ elements.ForEachElement([&](const xla::ShapeIndex& index,
+ const ExpandedTupleInput& element) {
+ if (elements.IsLeaf(index)) {
+ allocation_tmp->buffers_.CopySubtreeFrom(element.allocation->buffers_, {},
+ index);
+ tuple_buffers.set_buffer(element.allocation->root_allocation(), index);
+ if (element.release_allocation_after_use) {
+ // Transfer the references from element's buffers to the new allocation
+ // rather than incrementing the refcount. The caller should have
+ // validated that release_allocation_after_use is false if
+ // element.allocation appears in more than one leaf.
+ element.allocation->buffers_.ForEachMutableElement(
+ [&](const xla::ShapeIndex& index, XRTBufferAllocationPtr* buffer) {
+ *buffer = new XRTBufferAllocation(
+ se::DeviceMemoryBase(), element.allocation->device_ordinal(),
+ element.allocation->allocator_);
+ });
+ } else {
+ // Increment the refcount on each newly-aliased buffer.
+ element.allocation->buffers_.ForEachElement(
+ [](const xla::ShapeIndex& index,
+ const XRTBufferAllocationPtr& buffer) { buffer->Ref(); });
+ }
+ } else {
+ // This is an internal node of the tuple tree so take ownership of the
+ // newly-created index table.
+ *allocation_tmp->buffers_.mutable_element(index) =
+ new XRTBufferAllocation(tuple_buffers.buffer(index), device_ordinal,
+ allocator);
+ }
+ });
+ // Because the internal nodes of tuple_buffers are exactly the new index
+ // tables, WriteTupleIndexTables will write only the new index tables and not
+ // rewrite the index tables for the existing allocations.
+ TF_RETURN_IF_ERROR(
+ transfer_manager->WriteTupleIndexTables(stream.get(), tuple_buffers));
+
+ *allocation = allocation_tmp;
+ // Get another reference since allocation_tmp will be Unrefed automatically on
+ // exit.
+ (*allocation)->Ref();
+ return Status::OK();
+}
+
+Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) {
+ *key = get_uid();
+ string key_string = strings::StrCat(*key);
+ return rm->Create(kTupleContainer, key_string, this);
+}
+
+bool XRTTupleAllocation::IsExclusiveOwner() {
+ for (const auto& buffer : buffers_) {
+ if (!buffer.second->RefCountIsOne()) return false;
+ }
+ return true;
+}
+
+void XRTTupleAllocation::InitializeFromShapedBuffer(
+ const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator, int device_ordinal) {
+ for (auto& buffer : buffers_) {
+ // Make a reference-counted version of the allocated buffer.
+ buffer.second = new XRTBufferAllocation(shaped_buffer.buffer(buffer.first),
+ device_ordinal, allocator);
+ }
+}
+
+xla::ShapedBuffer XRTTupleAllocation::ToShapedBuffer() {
+ xla::ShapedBuffer shaped_buffer(on_host_shape(), on_device_shape(),
+ allocator_->platform(), device_ordinal_);
+ for (const auto& buffer : buffers_) {
+ shaped_buffer.set_buffer(buffer.second->allocation(), buffer.first);
+ }
+ return shaped_buffer;
+}
+
+xla::ShapeTree<xla::MaybeOwningDeviceMemory>
+XRTTupleAllocation::ToDeviceMemoryTree(bool release) {
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> shaped_tree(on_device_shape());
+ for (const auto& buffer : buffers_) {
+ if (!release) {
+ *shaped_tree.mutable_element(buffer.first) = buffer.second->allocation();
+ } else {
+ *shaped_tree.mutable_element(buffer.first) = xla::OwningDeviceMemory(
+ buffer.second->allocation(), device_ordinal_, allocator_);
+ DiscardAllocation(buffer.first);
+ }
+ }
+ return shaped_tree;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_state.h b/tensorflow/compiler/xrt/xrt_state.h
new file mode 100644
index 0000000000..42705688dd
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_state.h
@@ -0,0 +1,208 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Classes for keeping track of on-device state.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_STATE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+
+namespace tensorflow {
+
+// TODO(misard) make this a Tensor if and when that makes sense.
+// A reference-counted wrapper around a buffer allocation. This maps an XLA
+// tuple index or a non-tuple XLA shape to a region of device memory. The device
+// memory buffer is freed when the reference count drops to zero.
+class XRTBufferAllocation : public core::RefCounted {
+ public:
+ XRTBufferAllocation(const se::DeviceMemoryBase& allocation,
+ int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator);
+ ~XRTBufferAllocation() override;
+
+ // The region of device memory being wrapped.
+ const se::DeviceMemoryBase& allocation();
+
+ // Sets the DeviceMemoryBase to be null. DiscardAllocation should be called
+ // when ownership of the underlying buffer has been transferred, e.g., to an
+ // output buffer when input and output buffers are aliased during
+  // execution. The call to DiscardAllocation prevents the device buffer from
+  // being freed when the reference count drops to zero.
+ void DiscardAllocation();
+
+ private:
+ se::DeviceMemoryBase allocation_;
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+};
+
+// Entry in the resource manager corresponding to an allocation handle returned
+// to a client. The handle identifies an immutable tuple of data in device
+// memory. New handles can be created in three ways: by passing a literal in
+// which case device memory is allocated and the literal is transferred to that
+// memory; by aliasing a sub-shape of an existing tuple-shaped handle; or by
+// aliasing a vector of existing handles to create a new tuple. The underlying
+// storage is reference-counted. When a handle is released, the reference count
+// of each storage buffer is decremented, and buffers with no outstanding
+// references are freed.
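+//
+// As an illustrative mapping (a sketch, not exhaustive), these creation paths
+// correspond to the factory methods declared below:
+//   - transferring a literal:            CreateAndTransfer
+//   - aliasing a sub-shape of a handle:   MakeSubBuffer
+//   - tupling existing handles together:  MakeTuple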
+class XRTTupleAllocation : public ResourceBase {
+ public:
+ ~XRTTupleAllocation() override;
+
+ // Allocates new device memory buffers sufficient to store literal, transfers
+  // literal to that memory, and returns an XRTTupleAllocation handle to the
+ // allocated buffers.
+ static Status CreateAndTransfer(const xla::Literal& literal,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+  // Wraps an existing ShapedBuffer in a new XRTTupleAllocation handle.
+ static Status CreateFromBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::Backend* backend, int device_ordinal,
+ XRTTupleAllocation** allocation);
+
+  // Aliases a sub-shape of parent and returns an XRTTupleAllocation handle
+  // to the sub-shape. If alias_parent_allocation is true, the buffers in the
+ // sub-shape will be shared between parent and the returned allocation,
+ // otherwise the overlapping buffers in parent will be replaced by
+ // nullptr.
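+  //
+  // Illustrative usage sketch (all names hypothetical): to alias element {0}
+  // of a tuple-shaped allocation 'parent' without transferring ownership:
+  //
+  //   XRTTupleAllocation* sub = nullptr;
+  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+  //       parent, /*subshape=*/{0}, &sub, /*alias_parent_allocation=*/true));
+  //   ...
+  //   sub->Unref();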
+ static Status MakeSubBuffer(XRTTupleAllocation* parent,
+ const xla::ShapeIndex& subshape,
+ XRTTupleAllocation** allocation,
+ bool alias_parent_allocation);
+
+ // A structure describing a leaf of a tree of tuples to expand. Each leaf
+ // contains an allocation and indicates whether or not the allocation's handle
+ // should be freed after incorporating its buffers into the expanded tree.
+ struct ExpandedTupleInput {
+ XRTTupleAllocation* allocation;
+ bool release_allocation_after_use;
+ };
+
+ // Returns a handle to a new tuple where the subtree of the new tuple at an
+ // index corresponding to a leaf of 'elements' is constructed from the
+ // allocation (i.e., a tuple or array) pointed to by that leaf. If
+ // release_allocation_after_use is false at a leaf, the new tuple will alias
+ // the input allocation at that leaf, otherwise the input allocation will be
+ // released. Input allocations may be repeated (appear in more than one leaf)
+ // in which case the corresponding buffers in the output tuple will alias. If
+  // an input is repeated, release_allocation_after_use must be false for every
+  // leaf where that input appears. The latter property is not validated by
+  // MakeTuple and must be enforced by the caller.
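+  //
+  // Illustrative usage sketch (all names hypothetical): to build a pair from
+  // two existing array-shaped allocations 'a' and 'b', aliasing both:
+  //
+  //   xla::ShapeTree<ExpandedTupleInput> elements(
+  //       xla::ShapeUtil::MakeTupleShape(
+  //           {a->on_host_shape(), b->on_host_shape()}));
+  //   *elements.mutable_element({0}) =
+  //       ExpandedTupleInput{a, /*release_allocation_after_use=*/false};
+  //   *elements.mutable_element({1}) =
+  //       ExpandedTupleInput{b, /*release_allocation_after_use=*/false};
+  //   XRTTupleAllocation* tuple = nullptr;
+  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeTuple(
+  //       backend, device_ordinal, elements, &tuple));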
+ static Status MakeTuple(xla::Backend* backend, int device_ordinal,
+ const xla::ShapeTree<ExpandedTupleInput>& elements,
+ XRTTupleAllocation** allocation);
+
+ // Retrieves the allocation interned under key from rm. The caller owns a
+ // reference to allocation after looking it up.
+ static Status Lookup(ResourceMgr* rm, int64 key,
+ XRTTupleAllocation** allocation);
+
+ // Deletes the reference in the rm to an allocation interned under key.
+ static Status DeleteFromResourceManager(ResourceMgr* rm, int64 key);
+
+ // Adds the allocation to a ResourceMgr and returns the key that will be used
+ // to retrieve it. Transfers a reference on *this to rm.
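+  //
+  // Illustrative sketch (names hypothetical) of the intern/lookup round trip:
+  //
+  //   int64 key;
+  //   TF_RETURN_IF_ERROR(allocation->Intern(rm, &key));
+  //   ...
+  //   XRTTupleAllocation* found = nullptr;
+  //   TF_RETURN_IF_ERROR(XRTTupleAllocation::Lookup(rm, key, &found));
+  //   ...               // use 'found'
+  //   found->Unref();   // the caller owns this reference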
+ Status Intern(ResourceMgr* rm, int64* key);
+
+ // Copies the allocation from device to host and returns it in literal.
+ Status ToLiteral(xla::Backend* backend, int device_ordinal,
+ std::unique_ptr<xla::Literal>* literal);
+
+ // True if none of the buffers in the allocation are aliased by any other live
+ // handle.
+ bool IsExclusiveOwner();
+
+ // The ordinal of the device holding this tuple.
+ int device_ordinal();
+
+ // Returns the shape of the tuple as seen by the host.
+ const xla::Shape& on_host_shape();
+
+ // Returns the shape of the tuple as stored on the device.
+ const xla::Shape& on_device_shape();
+
+ // Returns the buffer pointed to by the root of the tuple.
+ const se::DeviceMemoryBase& root_allocation();
+
+ // Stops managing the storage for the allocation at buffer_index, e.g.,
+ // because it has been aliased to the output buffer of a computation.
+ void DiscardAllocation(const xla::ShapeIndex& buffer_index);
+
+  // Returns the tree of allocations as a ShapedBuffer. Note that the buffer
+  // tree is indexed by on_device_shape, which may differ from on_host_shape.
+ xla::ShapedBuffer ToShapedBuffer();
+
+ // Returns the device memory tree of this allocation. If 'release' is set, the
+ // ownership of the device memory is transferred to the result.
+ xla::ShapeTree<xla::MaybeOwningDeviceMemory> ToDeviceMemoryTree(bool release);
+
+ string DebugString() override { return "XLA allocation handle"; }
+
+ private:
+ // Creates a new handle with (tuple) shape.
+ XRTTupleAllocation(int device_ordinal, xla::DeviceMemoryAllocator* allocator,
+ const xla::Shape& on_host_shape,
+ const xla::Shape& on_device_shape);
+
+  // Inherits the allocations represented in shaped_buffer, which must have the
+  // same shape as buffers_.
+ void InitializeFromShapedBuffer(const xla::ShapedBuffer& shaped_buffer,
+ xla::DeviceMemoryAllocator* allocator,
+ int device_ordinal);
+
+  // Takes a tree 'elements' where each leaf is an allocation, validates that
+  // they are all on the device indicated by device_ordinal and managed by
+  // allocator, and returns in host_shape and device_shape the host/device
+  // shapes of the expanded tree, with the shape of the allocation at each leaf
+  // of 'elements' grafted onto the corresponding leaf position.
+ static Status ExpandTreeOfTuples(
+ const xla::ShapeTree<ExpandedTupleInput>& elements, int device_ordinal,
+ xla::DeviceMemoryAllocator* allocator, xla::Shape* host_shape,
+ xla::Shape* device_shape);
+
+ // Location of the memory that is being managed.
+ int device_ordinal_;
+ xla::DeviceMemoryAllocator* allocator_;
+
+ // The shape that the caller thinks the tuple has.
+ const xla::Shape on_host_shape_;
+ // The shape that the tuple has on device. Store this explicitly instead of
+ // using a shape stored in ShapeTree because ShapeTree discards the layout.
+ const xla::Shape on_device_shape_;
+ // The tree of reference-counted buffers, which uses on_device_shape_ as its
+ // shape.
+ xla::ShapeTree<XRTBufferAllocation*> buffers_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_STATE_H_