aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-20 15:37:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 15:45:36 -0700
commit4d39844c1dafb6b74ad49b231bc949a2e026f5ea (patch)
tree0f3043c6551f6be17026ceafe62b8dcffc026da8 /tensorflow/compiler/jit
parentd78b3484d4b98790c2d3a7c0d861487e2fcdefdf (diff)
Split XlaLaunch into XlaCompile and XlaRun; NFC
This CL splits the functionality in XlaLaunch into two separate operations: - XlaCompile, responsible for compiling a TF function into a LocalExecutable - XlaRun, responsible for executing a LocalExecutable created by XlaCompile This CL is a stepping stone towards implementing lazy compilation for TF/XLA. The XlaCompile op is spec'ed to return a boolean indicating whether the compilation was successful. Right now that boolean is always set to true by XlaCompile and its value is otherwise ignored, but in the future it will be used to indicate whether the TF function was compiled or not, and thus whether we should execute XlaRun or just directly call the TF function. XlaLaunch still exists, and will be created by create_xla_launch_op.cc. In the future we may consider removing it altogether. build_xla_launch_ops.cc, now renamed to build_xla_ops.cc, creates a XlaCompile/XlaRun pair instead of XlaLaunch. This CL is organized as follows: - jit/ops/xla_ops.cc gets two new XLA-specific operations, XlaCompile and XlaRun, described above. XlaRun redundantly takes the must-be-constant inputs to the TensorFlow cluster to keep the implementation simple (simple in the sense of similar to XlaLaunch), but I will remove this in a subsequent cleanup CL. - jit/kernels/xla_ops.cc implements XlaCompile and XlaRun in a fairly straightforward manner. XlaCompile compiles the TF function, puts it in a process-global storage, XlaExecutableClosureStore, and produces a int64 key. XlaRun uses the key to read out the LocalExecutable and execute it. I'm not sure if XlaExecutableClosureStore should be a resource like XlaCompilationCache; I did not immediately see any reason to make it so. - There are changes to the various _device files to register XlaCompile and XlaRun for the XLA_* devices. - Finally, I had to fix some tests that were expecting XlaLaunch in the execution timeline. PiperOrigin-RevId: 213895405
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/BUILD24
-rw-r--r--tensorflow/compiler/jit/build_xla_launch_ops_pass.cc142
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc187
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.h (renamed from tensorflow/compiler/jit/build_xla_launch_ops_pass.h)10
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc2
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc4
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD7
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc276
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h87
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.cc488
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.h168
-rw-r--r--tensorflow/compiler/jit/ops/xla_ops.cc43
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc5
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h11
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc5
-rw-r--r--tensorflow/compiler/jit/xla_interpreter_device.cc6
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc2
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h3
18 files changed, 938 insertions, 532 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 9544c365b7..4e184729ef 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -51,7 +51,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -63,7 +63,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda([
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
@@ -77,7 +77,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -95,7 +95,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
@@ -112,7 +112,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
@@ -281,7 +281,7 @@ cc_library(
deps = [
":common",
":compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -342,7 +342,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -360,7 +360,7 @@ tf_cc_test(
cc_library(
name = "compilation_passes",
srcs = [
- "build_xla_launch_ops_pass.cc",
+ "build_xla_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@@ -370,7 +370,7 @@ cc_library(
"partially_decluster_pass.cc",
],
hdrs = [
- "build_xla_launch_ops_pass.h",
+ "build_xla_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@@ -460,7 +460,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -494,7 +494,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -525,7 +525,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff589e2..0000000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
- Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("XlaLaunch");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
- VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
- int num_constant_args, num_resource_args;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
- if (num_constant_args < 0 || num_resource_args < 0 ||
- num_constant_args + num_resource_args > node->num_inputs()) {
- return errors::InvalidArgument(
- "Invalid number of constant/resource arguments to XLA kernel.");
- }
- const int num_nonconst_args =
- node->num_inputs() - num_constant_args - num_resource_args;
-
- DataTypeVector const_dtypes(node->input_types().begin(),
- node->input_types().begin() + num_constant_args);
- DataTypeVector arg_dtypes(
- node->input_types().begin() + num_constant_args,
- node->input_types().begin() + num_constant_args + num_nonconst_args);
-
- // Build a XlaLaunch operator to execute the function body.
- Node* launch_node;
- TF_RETURN_IF_ERROR(BuildLaunchNode(
- graph->NewName(node->name()), node->type_string(), node->def().attr(),
- node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
- node->output_types(), graph, &launch_node));
- launch_node->set_assigned_device_name(node->assigned_device_name());
-
- // Copy incoming edges to the launch node.
- for (const Edge* edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(edge->src(), launch_node);
- } else {
- graph->AddEdge(edge->src(), edge->src_output(), launch_node,
- edge->dst_input());
- }
- }
-
- // Copy outgoing edges to the launch node.
- std::vector<const Edge*> out_edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(launch_node, dst);
- } else {
- graph->AddEdge(launch_node, src_output, dst, dst_input);
- }
- }
- graph->RemoveNode(node);
-
- return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
-
- for (Node* n : graph->op_nodes()) {
- // In all cases, only try to compile computational nodes.
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
- continue;
- }
-
- // Only compile nodes that are marked for compilation by the
- // compilation-marking pass (via 'attr_name').
- if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
- }
- }
-
- if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000000..a6086f30a1
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,187 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+static Status BuildXlaCompileNode(
+ const string& nodename, const string& function_name,
+ const AttrValueMap& function_attr, const string& device_name,
+ const DataTypeVector& constant_dtypes, int num_resources,
+ const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaCompile");
+ def.set_device(device_name);
+ AddNodeAttr("Tconstants", constant_dtypes, &def);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Nresources", num_resources, &def);
+ NameAttrList function;
+ function.set_name(function_name);
+ *function.mutable_attr() = function_attr;
+ AddNodeAttr("function", function, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status BuildXlaRunNode(const string& nodename, const string& device_name,
+ const DataTypeVector& constant_dtypes,
+ const DataTypeVector& arg_dtypes,
+ const DataTypeVector& result_dtypes, Graph* graph,
+ Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaRun");
+ def.set_device(device_name);
+ AddNodeAttr("Tconstants", constant_dtypes, &def);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Tresults", result_dtypes, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status GetXlaAttrs(Node* node, int* num_constant_args,
+ int* num_resource_args, DataTypeVector* const_dtypes,
+ DataTypeVector* arg_dtypes) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+
+ if (*num_constant_args < 0 || *num_resource_args < 0 ||
+ *num_constant_args + *num_resource_args > node->num_inputs()) {
+ return errors::InvalidArgument(
+ "Invalid number of constant/resource arguments to XLA kernel.");
+ }
+
+ const int num_nonconst_args =
+ node->num_inputs() - *num_constant_args - *num_resource_args;
+
+ const DataTypeVector& input_types = node->input_types();
+ std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
+ std::back_inserter(*const_dtypes));
+ std::copy(input_types.begin() + *num_constant_args,
+ input_types.begin() + *num_constant_args + num_nonconst_args,
+ std::back_inserter(*arg_dtypes));
+ return Status::OK();
+}
+
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node) {
+ for (const Edge* edge : old_node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(edge->src(), new_node);
+ } else {
+ g->AddEdge(edge->src(), edge->src_output(), new_node, edge->dst_input());
+ }
+ }
+}
+
+static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ Node* dst = edge->dst();
+ int src_output = edge->src_output();
+ int dst_input = edge->dst_input();
+ g->RemoveEdge(edge);
+
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(new_node, dst);
+ } else {
+ g->AddEdge(new_node, src_output, dst, dst_input);
+ }
+ }
+}
+
+static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
+ int num_constant_args, num_resource_args;
+ DataTypeVector const_dtypes;
+ DataTypeVector arg_dtypes;
+
+ TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
+ &const_dtypes, &arg_dtypes));
+
+ Node *compile_node, *run_node;
+
+ TF_RETURN_IF_ERROR(BuildXlaCompileNode(
+ n->name(), n->type_string(), n->def().attr(), n->requested_device(),
+ const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+
+ DataTypeVector arg_dtypes_with_resources = arg_dtypes;
+ for (int i = 0; i < num_resource_args; i++) {
+ arg_dtypes_with_resources.push_back(DT_RESOURCE);
+ }
+
+ TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
+ const_dtypes, arg_dtypes_with_resources,
+ n->output_types(), g, &run_node));
+
+ compile_node->set_assigned_device_name(n->assigned_device_name());
+ run_node->set_assigned_device_name(n->assigned_device_name());
+
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node);
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+
+ // The compilation_key output.
+ g->AddEdge(compile_node, 0, run_node, n->num_inputs());
+
+ MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+
+ for (Node* n : graph->op_nodes()) {
+ // In all cases, only try to compile computational nodes.
+ if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+ continue;
+ }
+
+ // Only compile nodes that are marked for compilation by the
+ // compilation-marking pass (via 'attr_name').
+ if (IsXlaCompiledKernel(*n)) {
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+ }
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93f02..1dd38fa951 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
+// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a30b..6f1ff85f24 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 3770eea6d0..085c0e5adb 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -55,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
- BuildXlaLaunchOpsPass);
+ BuildXlaOpsPass);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2547..0839f1cb3d 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@ package(
)
cc_library(
- name = "xla_launch_op",
- srcs = ["xla_launch_op.cc"],
- hdrs = ["xla_launch_op.h"],
+ name = "xla_ops",
+ srcs = ["xla_ops.cc"],
+ hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f632f7..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
- if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = se::host::kHostPlatformId;
- } else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = ctx->device()
- ->tensorflow_gpu_device_info()
- ->stream->parent()
- ->platform()
- ->id();
- } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
- use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
- platform_id_ = xla_device_metadata_->platform()->id();
- }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
- if (xla_device_metadata_) {
- *cache = new XlaCompilationCache(xla_device_metadata_->client(),
- xla_device_metadata_->jit_device_type());
- return Status::OK();
- }
-
- auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
- if (!platform.ok()) {
- return platform.status();
- }
- xla::LocalClientOptions client_options;
- client_options.set_platform(platform.ValueOrDie());
- client_options.set_intra_op_parallelism_threads(
- ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
- auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
- if (!client.ok()) {
- return client.status();
- }
- const XlaOpRegistry::DeviceRegistration* registration;
- if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
- &registration)) {
- return errors::InvalidArgument("No JIT device registered for ",
- device_type_.type());
- }
- *cache = new XlaCompilationCache(
- client.ValueOrDie(), DeviceType(registration->compilation_device_name));
- return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
- << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
- // We store information about the JIT-compiled XLA computation
- // in the ResourceMgr.
- ResourceMgr* rm = ctx->resource_manager();
- OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
- se::Stream* stream =
- ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
- XlaCompilationCache* cache;
- OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
- rm->default_container(), "xla_cache", &cache,
- [this, ctx](XlaCompilationCache** cache) {
- return BuildCompilationCache(ctx, cache);
- }));
- // Hold the reference to the JIT during evaluation. (We could probably
- // free it sooner because the ResourceMgr will retain a reference, but
- // this is more obviously correct.)
- core::ScopedUnref cache_ref(cache);
-
- std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
-
- xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
- XlaAllocator local_xla_allocator(client->backend().platform(),
- ctx->device()->GetAllocator({}));
- xla::DeviceMemoryAllocator* xla_allocator;
- // If we are on an XlaDevice, use the underlying XLA platform's allocator
- // directly. We could use the StreamExecutor's allocator which may
- // theoretically be more correct, but XLA returns a nice OOM message in a
- // Status and StreamExecutor does not.
- //
- // Importantly we can't use ctx->device()->GetAllocator() as the allocator
- // (which local_xla_allocator above uses) as on an XlaDevice, this is a
- // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
- // real allocator to allocate real buffers.
- if (xla_device_metadata_) {
- xla_allocator = client->backend().memory_allocator();
- } else {
- xla_allocator = &local_xla_allocator;
- }
-
- XlaCompiler::Options options;
- options.client = client;
- if (ctx->op_device_context() != nullptr) {
- options.device_ordinal =
- ctx->op_device_context()->stream()->parent()->device_ordinal();
- }
- options.device_type = cache->device_type();
- options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
- options.device_allocator = xla_allocator;
- if (xla_device_metadata_) {
- options.shape_representation_fn =
- xla_device_metadata_->shape_representation_fn();
- }
-
- const XlaCompiler::CompilationResult* kernel;
- xla::LocalExecutable* executable;
-
- std::map<int, Tensor> constant_args;
- for (int i : constants_) {
- constant_args.insert({i, ctx->input(i)});
- }
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
- // Optimization: where possible, have the computation return a naked array
- // rather than a one-element tuple.
- compile_options.always_return_tuple = false;
-
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, compile_options));
-
- VLOG(1) << "Executing XLA Computation...";
-
- XlaComputationLaunchContext launch_context(
- client, xla_allocator,
- /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
- use_multiple_streams_);
- launch_context.PopulateInputs(ctx, kernel, variables);
-
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(GetXLARandomSeed());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
-
- auto run_result = executable->Run(launch_context.arguments(), run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
-
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
- VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
- VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
- .Device(DEVICE_GPU)
- .HostMemory("constants")
- .HostMemory("resources"),
- XlaLocalLaunchOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9817..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_ = nullptr;
- bool use_multiple_streams_ = false;
- const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
- explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
- ~XlaLocalLaunchOp() override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000000..c483841a7c
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,488 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+ XlaPlatformInfo* result) {
+ DeviceType device_type = ctx->device_type();
+ se::Platform::Id platform_id = nullptr;
+ const XlaDevice::Metadata* xla_device_metadata = nullptr;
+ std::unique_ptr<XlaAllocator> xla_allocator;
+ xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+ if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+ platform_id = se::host::kHostPlatformId;
+ } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+ platform_id = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+ // If we are on an XlaDevice, use the underlying XLA platform's allocator
+ // directly. We could use the StreamExecutor's allocator which may
+ // theoretically be more correct, but XLA returns a nice OOM message in a
+ // Status and StreamExecutor does not.
+ //
+ // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+ // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
+ // allocator that returns XlaTensor objects. The XlaCompiler needs a real
+ // allocator to allocate real buffers.
+
+ platform_id = xla_device_metadata->platform()->id();
+ device_allocator =
+ xla_device_metadata->client()->backend().memory_allocator();
+ }
+
+ if (!device_allocator) {
+ TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+ se::MultiPlatformManager::PlatformWithId(platform_id));
+ xla_allocator = absl::make_unique<XlaAllocator>(
+ platform, ctx->device()->GetAllocator({}));
+ }
+
+ *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+ std::move(xla_allocator), device_allocator);
+
+ return Status::OK();
+}
+
+// A closure describing how to run a compiled version of a TensorFlow function.
+//
+// It may seem unusual to stick the resource variable snapshots in this class.
+// This is necessary: we need to use the snapshots observed by the compiler as
+// the initial values for the resource variables (and cannot snapshot them again
+// during execution) because otherwise we risk observing a different snapshot
+// with shapes different from what we compiled for.
+class XlaExecutableClosure {
+ public:
+ explicit XlaExecutableClosure(
+ xla::LocalClient* client, xla::LocalExecutable* executable,
+ const XlaCompiler::CompilationResult* compilation_result,
+ std::map<int, OptionalTensor> resource_var_snapshots)
+ : client_(client),
+ executable_(executable),
+ compilation_result_(compilation_result),
+ resource_var_snapshots_(std::move(resource_var_snapshots)) {}
+
+ XlaExecutableClosure(XlaExecutableClosure&&) = default;
+ XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+
+ xla::LocalClient* client() const { return client_; }
+ xla::LocalExecutable* executable() const { return executable_; }
+ const XlaCompiler::CompilationResult* compilation_result() const {
+ return compilation_result_;
+ }
+ const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+ return resource_var_snapshots_;
+ }
+
+ private:
+ xla::LocalClient* client_;
+ xla::LocalExecutable* executable_;
+ const XlaCompiler::CompilationResult* compilation_result_;
+ std::map<int, OptionalTensor> resource_var_snapshots_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
+};
+
+// This maintains a mapping from a globally unique ID to XlaExecutableClosure
+// instances.
+class XlaExecutableClosureStore {
+ public:
+ XlaExecutableClosureStore() : key_counter_(0) {}
+
+ using KeyT = string;
+
+ KeyT Produce(XlaExecutableClosure result) {
+ mutex_lock l(mutex_);
+ KeyT key = absl::StrCat(key_counter_++);
+ bool insert_successful = closures_.emplace(key, std::move(result)).second;
+ DCHECK(insert_successful);
+ (void)insert_successful;
+ return key;
+ }
+
+ XlaExecutableClosure Consume(const KeyT& key) {
+ mutex_lock l(mutex_);
+ auto it = closures_.find(key);
+ DCHECK(it != closures_.end());
+ XlaExecutableClosure value = std::move(it->second);
+ closures_.erase(it);
+ return value;
+ }
+
+ static XlaExecutableClosureStore* Global() {
+ static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+ return instance;
+ }
+
+ private:
+ mutex mutex_;
+ int64 key_counter_ GUARDED_BY(mutex_);
+ gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
+
+} // namespace
+
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ function_(function) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+static Status BuildCompilationCache(OpKernelContext* ctx,
+ const XlaPlatformInfo& platform_info,
+ XlaCompilationCache** cache) {
+ if (platform_info.xla_device_metadata()) {
+ *cache = new XlaCompilationCache(
+ platform_info.xla_device_metadata()->client(),
+ platform_info.xla_device_metadata()->jit_device_type());
+ return Status::OK();
+ }
+
+ auto platform =
+ se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+ if (!platform.ok()) {
+ return platform.status();
+ }
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform.ValueOrDie());
+ client_options.set_intra_op_parallelism_threads(
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+ if (!client.ok()) {
+ return client.status();
+ }
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+ &registration)) {
+ return errors::InvalidArgument("No JIT device registered for ",
+ platform_info.device_type().type());
+ }
+ *cache = new XlaCompilationCache(
+ client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+ return Status::OK();
+}
+
+static Status CompileToLocalExecutable(
+ OpKernelContext* ctx, const NameAttrList& function,
+ const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+ absl::Span<const int> constants, xla::LocalClient** client,
+ std::map<int, OptionalTensor>* variables,
+ const XlaCompiler::CompilationResult** kernel,
+ xla::LocalExecutable** executable) {
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ if (!rm) {
+ return errors::Internal("No resource manager.");
+ }
+
+ XlaCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_cache", &cache,
+ [&](XlaCompilationCache** cache) {
+ return BuildCompilationCache(ctx, platform_info, cache);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref cache_ref(cache);
+
+ *variables = SnapshotResourceVariables(ctx, resources);
+ *client = static_cast<xla::LocalClient*>(cache->client());
+
+ XlaCompiler::Options options;
+ options.client = *client;
+ if (ctx->op_device_context() != nullptr) {
+ options.device_ordinal =
+ ctx->op_device_context()->stream()->parent()->device_ordinal();
+ }
+ options.device_type = cache->device_type();
+ options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ options.graph_def_version = ctx->function_library()->graph_def_version();
+ options.allow_cpu_custom_calls =
+ (platform_info.platform_id() == se::host::kHostPlatformId);
+ options.device_allocator = platform_info.allocator();
+ if (platform_info.xla_device_metadata()) {
+ options.shape_representation_fn =
+ platform_info.xla_device_metadata()->shape_representation_fn();
+ }
+
+ std::map<int, Tensor> constant_args;
+ for (int i : constants) {
+ constant_args.insert({i, ctx->input(i)});
+ }
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.is_entry_computation = true;
+ // If we resolve constants we never emit them on the device, meaning that if
+ // they are needed by a following computation the host has to transfer
+ // them. Not resolving constants is expected to be faster than resolving
+ // constants.
+ compile_options.resolve_compile_time_constants = true;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
+
+ return cache->Compile(options, function, constant_args, *variables, ctx,
+ kernel, executable, compile_options);
+}
+
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+ << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ XlaComputationLaunchContext launch_context(
+ client, platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ platform_info_.UseMultipleStreams());
+ launch_context.PopulateInputs(ctx, kernel, variables);
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result = executable->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie()));
+ VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+ VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx),
+ constants_(ConstantsVector(ctx)),
+ resources_(ResourcesVector(ctx)),
+ function_(FunctionAttr(ctx)) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
+ // if it didn't have to compile the cluster because of a compilation-cache
+ // hit. This is because we at least need new snapshots of the resource
+ // variables.
+ XlaExecutableClosureStore::KeyT key =
+ XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+ client, executable, kernel, std::move(variables)));
+
+ Allocator* cpu_allocator = [&] {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ return ctx->device()->GetAllocator(host_alloc_attrs);
+ }();
+
+ Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+ compilation_key.flat<string>()(0) = key;
+
+ Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+ compilation_successful.flat<bool>()(0) = true;
+
+ ctx->set_output(0, compilation_key);
+ ctx->set_output(1, compilation_successful);
+}
+
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+ Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
+ const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+ XlaExecutableClosure closure =
+ XlaExecutableClosureStore::Global()->Consume(key);
+
+ XlaComputationLaunchContext launch_context(
+ closure.client(), platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+ launch_context.PopulateInputs(ctx, closure.compilation_result(),
+ closure.resource_var_snapshots());
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result =
+ closure.executable()->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
+
+ OP_REQUIRES_OK(
+ ctx, launch_context.PopulateOutputs(ctx, closure.compilation_result(),
+ run_result.ConsumeValueOrDie()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+
+REGISTER_KERNEL_BUILDER(
+ Name("_XlaRun").Device(DEVICE_GPU).HostMemory("constants"), XlaRunOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000000..489d26eb30
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run on.
+class XlaPlatformInfo {
+ public:
+ XlaPlatformInfo() : device_type_("") {}
+ explicit XlaPlatformInfo(const DeviceType device_type,
+ se::Platform::Id platform_id,
+ const XlaDevice::Metadata* xla_device_metadata,
+ std::unique_ptr<XlaAllocator> xla_allocator,
+ xla::DeviceMemoryAllocator* device_allocator)
+ : device_type_(device_type),
+ platform_id_(platform_id),
+ xla_device_metadata_(xla_device_metadata),
+ xla_allocator_(std::move(xla_allocator)),
+ device_allocator_(device_allocator) {
+ CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+ }
+
+ XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+ bool UseMultipleStreams() const {
+ return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+ }
+
+ xla::DeviceMemoryAllocator* allocator() const {
+ return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+ }
+ DeviceType device_type() const { return device_type_; }
+
+ // This is equal to xla_device_metadata()->platform()->id() if
+ // xla_device_metadata() is not nullptr.
+ se::Platform::Id platform_id() const { return platform_id_; }
+
+ // This may be null if the op this XlaPlatformInfo is for was not placed on an
+ // XLA device.
+ const XlaDevice::Metadata* xla_device_metadata() const {
+ return xla_device_metadata_;
+ }
+ bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+ DeviceType device_type_;
+ se::Platform::Id platform_id_;
+
+ // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+ // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+ // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+ const XlaDevice::Metadata* xla_device_metadata_;
+
+ // If the op associated with this XlaPlatformInfo is placed on an XLA device
+ // then device_allocator_ is the xla::Backend's memory allocator and
+ // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
+ // then device_allocator_ is null and xla_allocator_ points to an appropriate
+ // XlaAllocator instance.
+ std::unique_ptr<XlaAllocator> xla_allocator_;
+ xla::DeviceMemoryAllocator* device_allocator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+ XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+class XlaCompileOp : public OpKernel {
+ public:
+ explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+
+ XlaPlatformInfo platform_info_;
+};
+
+class XlaRunOp : public OpKernel {
+ public:
+ explicit XlaRunOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ XlaPlatformInfo platform_info_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3caab..6b4cdaa1c1 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -51,4 +51,47 @@ REGISTER_OP("XlaClusterOutput")
"Operator that connects the output of an XLA computation to other "
"consumer graph nodes.");
+REGISTER_OP("_XlaCompile")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Input("resources: Nresources * resource")
+ .Attr("Nresources: int >= 0")
+ .Output("key: string")
+ .Output("compilation_successful: bool")
+ .Attr("function: func")
+ // The compilation cache is stateful.
+ .SetIsStateful()
+ .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+ node and associated metadata.
+
+compilation_successful: True iff the compilation was successful. Always true
+for now.
+)");
+
+REGISTER_OP("_XlaRun")
+ // TODO(sanjoy): We don't need constants and Tconstants and they should be
+ // removed.
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ .Input("key: string")
+ // XLA random-number generation ops are stateful.
+ // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+ .SetIsStateful()
+ .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index e26fa27b31..003c1d8081 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
@@ -70,6 +70,9 @@ constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582682..639243973c 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,17 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.HostMemory("resources"), \
KERNEL);
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \
+ .Device(DEVICE) \
+ .HostMemory("constants") \
+ .HostMemory("resources"), \
+ KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_XlaRun").Device(DEVICE).HostMemory("constants"), KERNEL);
+
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index c386984930..60979556a3 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -79,6 +79,9 @@ constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559674..19e681af0c 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@ limitations under the License.
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+ kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 5f2f6801e7..07a93e9c39 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,7 +42,7 @@ using xla::ShapedBuffer;
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+ OpKernelContext* ctx, absl::Span<const int> variables) {
std::map<int, OptionalTensor> snapshot;
for (int i : variables) {
Var* variable = nullptr;
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275fab8..fa7a5e5f89 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class XlaAllocator;
@@ -43,7 +44,7 @@ class XlaAllocator;
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+ OpKernelContext* ctx, absl::Span<const int> variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation: