author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-11 00:50:04 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-11 00:54:33 -0700
commit    45965cfd8b54fb113275ffdaced5366e28aa3553 (patch)
tree      253c390dceb910360cb3b62d5039bcbcdf0f5c5d /tensorflow/compiler/jit
parent    5375f8c48b3087512f7593cf699346cc0b30a27b (diff)
Graph optimization pass that creates XlaLaunch ops for the computations that have been explicitly marked to be compiled via xla.compile()
PiperOrigin-RevId: 212407112
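
This pass consumes graphs in the shape that xla.compile() produces: every node of the computation carries the _xla_compile_id attribute, and each computation output is routed through an unmarked XlaClusterOutput node. A minimal C++ sketch of that input shape, mirroring the unit test added in this change (the helper name MakeMarkedGraph is illustrative, not part of the change):

    #include "absl/memory/memory.h"
    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
    #include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"

    namespace tensorflow {

    // Builds a graph of the shape this pass expects: marked compute nodes
    // plus unmarked XlaClusterOutput consumers. Illustrative only.
    std::unique_ptr<Graph> MakeMarkedGraph() {
      Scope scope = Scope::NewRootScope().ExitOnError();
      auto a = ops::Placeholder(scope.WithOpName("A"), DT_FLOAT);
      auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
      auto sum = ops::Add(scope.WithOpName("Sum"), a, b);
      // xla.compile() tags every node of the computation with a unique key.
      sum.node()->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
                          "cluster0");
      // Outputs cross the cluster boundary through XlaClusterOutput nodes,
      // which deliberately do not carry the attribute.
      auto out = ops::XlaClusterOutput(scope.WithOpName("Out"), sum);
      auto graph = absl::make_unique<Graph>(OpRegistry::Global());
      TF_CHECK_OK(scope.ToGraph(graph.get()));
      return graph;
    }

    }  // namespace tensorflow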
Diffstat (limited to 'tensorflow/compiler/jit')
 tensorflow/compiler/jit/BUILD                                     |   6
 tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc             |  17
 tensorflow/compiler/jit/encapsulate_subgraphs_pass.h              |   6
 tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc      | 360
 tensorflow/compiler/jit/encapsulate_xla_computations_pass.h       |  61
 tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc | 346
 tensorflow/compiler/jit/jit_compilation_pass_registration.cc      |   7
 tensorflow/compiler/jit/ops/xla_ops.cc                            |  19
 8 files changed, 822 insertions, 0 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index a989f15a1c..352f63bc98 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -362,6 +362,7 @@ cc_library(
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
+ "encapsulate_xla_computations_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -370,6 +371,7 @@ cc_library(
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
+ "encapsulate_xla_computations_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@@ -396,6 +398,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@@ -474,6 +477,7 @@ tf_cc_test(
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
+ "encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@@ -489,7 +493,9 @@ tf_cc_test(
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index ae7a22f451..e0632ff7e4 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -58,6 +59,22 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
+void SortControlInputs(GraphDef* gdef) {
+ int64 num_nodes = gdef->node_size();
+ for (int64 i = 0; i < num_nodes; ++i) {
+ NodeDef* node = gdef->mutable_node(i);
+ // Stable sort control inputs and leave the order of data inputs unchanged.
+ std::stable_sort(node->mutable_input()->begin(),
+ node->mutable_input()->end(),
+ [](const string& a, const string& b) {
+ bool a_is_control = absl::StartsWith(a, "^");
+ bool b_is_control = absl::StartsWith(b, "^");
+ return (!a_is_control && b_is_control) ||
+ (a_is_control && b_is_control && a < b);
+ });
+ }
+}
+
namespace {
bool AreAllParentsGuaranteedConst(
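
The comparator added above keeps data inputs in front (in their original, stable order) and sorts control inputs by name. A standalone illustration of its effect on a NodeDef input list, with the absl::StartsWith test restated as a plain character check:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> inputs = {"^ctrl_b", "x", "^ctrl_a", "y"};
      std::stable_sort(inputs.begin(), inputs.end(),
                       [](const std::string& a, const std::string& b) {
                         const bool a_is_control = !a.empty() && a[0] == '^';
                         const bool b_is_control = !b.empty() && b[0] == '^';
                         return (!a_is_control && b_is_control) ||
                                (a_is_control && b_is_control && a < b);
                       });
      for (const std::string& s : inputs) std::cout << s << " ";
      std::cout << "\n";  // Prints: x y ^ctrl_a ^ctrl_b
      return 0;
    }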
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 926589546f..90354a801a 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,6 +102,12 @@ extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
+// Sorts each node's control inputs by their names. This guarantees that for
+// two structurally equivalent GraphDefs, we get the same traversal ordering
+// on each node's control input fields.
+// TODO(hpucha): Move the utilities to a more appropriate place.
+void SortControlInputs(GraphDef* gdef);
+
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
new file mode 100644
index 0000000000..97ef8cd3cb
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -0,0 +1,360 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
+
+namespace tensorflow {
+
+const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
+ "_xla_compile_id";
+
+namespace {
+
+const char* const kXlaClusterOutput = "XlaClusterOutput";
+
+// Checks if a graph node is marked to be a guaranteed constant.
+bool is_guaranteed_constant(const Node& n) {
+ bool guaranteed_constant = false;
+ if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
+ .ok()) {
+ return false;
+ }
+ return guaranteed_constant;
+}
+
+// Finds the `index` of an _Arg or _Retval node.
+Status GetIndexAttr(const Node& n, int num_args, int* index) {
+ TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
+ if (*index < 0 || *index >= num_args) {
+ return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
+ *index);
+ }
+ return Status::OK();
+}
+
+// Returns the data type of the destination of an edge.
+DataType EdgeType(const Edge* edge) {
+ return edge->dst()->input_type(edge->dst_input());
+}
+
+// Adds the control inputs of `node` to `*deps`.
+void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.in_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->src());
+ }
+ }
+}
+
+// Adds the control outputs of `node` to `*deps`.
+void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+ for (const Edge* edge : node.out_edges()) {
+ if (edge->IsControlEdge()) {
+ deps->insert(edge->dst());
+ }
+ }
+}
+
+// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
+// the arguments into the order expected by XlaLaunch computations:
+// 1) arguments
+// 2) resource variable arguments
+// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
+// of the arguments.
+//
+// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
+Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation,
+ NodeDef* call_def) {
+ Graph* graph = graph_ptr->get();
+ const int num_args = input_permutation->size();
+ const int num_retvals = output_permutation->size();
+
+ std::vector<Node*> args;
+ std::vector<Node*> retvals;
+ args.reserve(num_args);
+ retvals.reserve(num_retvals);
+ for (Node* n : graph->nodes()) {
+ if (n->type_string() == "_Arg") {
+ // Check if this is a guaranteed constant.
+ if (is_guaranteed_constant(*n)) {
+ return errors::InvalidArgument(
+ "Guaranteed constants are not supported (", n->name(), ")");
+ }
+ args.push_back(n);
+ } else if (n->type_string() == "_Retval") {
+ retvals.push_back(n);
+ }
+ }
+
+ if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
+ return errors::InvalidArgument("Missing or non-consecutive arguments");
+ }
+
+ // Reorders the arguments.
+ std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
+ // Non-resources appear before resources
+ bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
+ bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
+ // Uses the name as a tiebreaker so the output is deterministic.
+ StringPiece a_name(a->name());
+ StringPiece b_name(b->name());
+ return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
+ });
+
+ // Sorts the retvals by name so the order is deterministic.
+ std::sort(retvals.begin(), retvals.end(),
+ [](Node* a, Node* b) { return a->name() < b->name(); });
+
+  // Computes the permutation to produce the correct argument order, and
+  // updates the argument indices.
+ int variable_start_index = num_args;
+ for (int i = 0; i < num_args; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
+ if (args[i]->output_type(0) == DT_RESOURCE &&
+ variable_start_index == num_args) {
+ variable_start_index = i;
+ }
+ (*input_permutation)[index] = i;
+ args[i]->AddAttr("index", i);
+ }
+ VLOG(4) << "variable_start_index: " << variable_start_index;
+
+  // Computes the permutation to produce the correct retval order, and
+  // updates the retval indices.
+ for (int i = 0; i < num_retvals; ++i) {
+ int index;
+ TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
+ (*output_permutation)[index] = i;
+ retvals[i]->AddAttr("index", i);
+ }
+
+ AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
+ call_def);
+ AddNodeAttr("_variable_start_index", variable_start_index, call_def);
+
+ // Uniquify the function name.
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+
+  // Before serialization, sort each node's control inputs to achieve
+  // determinism. Sorting control inputs helps produce (but does not by
+  // itself guarantee) a deterministic serialization and fingerprint; other
+  // sources of nondeterminism include unstable node ordering.
+ SortControlInputs(&gdef);
+  // Fingerprint the function. Nondeterminism in serialization would not lead
+  // to incorrect results, but it may cause spurious cache misses.
+  // SerializeToStringDeterministic is only a best-effort deterministic
+  // serialization.
+ string serialized;
+ TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
+ uint64 fingerprint = Fingerprint64(serialized);
+ LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
+ call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+  // Check for undeclared outputs before encapsulation, so we can give a
+  // better error message.
+ // TODO(phawkins): merge this with the encapsulation code to avoid the extra
+ // O(n) pass over the edges.
+ for (const Edge* e : (*graph)->edges()) {
+ if (!e->IsControlEdge() &&
+ e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
+ e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
+ e->dst()->type_string() != kXlaClusterOutput) {
+ return errors::InvalidArgument(
+ "Undeclared output of XLA computation. A common cause of this error "
+ "is variable initializers that depend on the XLA computation. Edge: ",
+ e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
+ e->dst_input());
+ }
+ }
+
+ auto output = absl::make_unique<Graph>((*graph)->op_registry());
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, "", **graph, RewriteSubgraph,
+ /*reuse_existing_functions=*/true, &output, flib_def),
+ "EncapsulateXlaComputationsPass failed");
+ graph->swap(output);
+ return Status::OK();
+}
+
+/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
+ Graph* graph) {
+  // Finds all of the XlaLaunch function calls up front, so that we do not
+  // mutate the graph while iterating over its nodes.
+ std::vector<Node*> launch_nodes;
+ for (Node* n : graph->nodes()) {
+ string name;
+ if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
+ launch_nodes.push_back(n);
+ }
+ }
+
+  // Replaces each launch function call, together with its neighboring
+  // XlaClusterOutput nodes, with an XlaLaunch node.
+ for (Node* launch : launch_nodes) {
+ int variable_start_index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
+ &variable_start_index));
+
+ std::vector<const Edge*> in_edges;
+ TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
+
+ const int num_inputs = in_edges.size();
+ const int num_variables = num_inputs - variable_start_index;
+ const int num_args = variable_start_index;
+
+ VLOG(4) << "Launch node '" << launch->name() << "'"
+ << " input edges: " << in_edges.size() << " num_args: " << num_args
+ << " num_variables: " << num_variables;
+
+ std::vector<Node*> nodes_to_remove = {launch};
+
+ // Data and control inputs to the new XlaLaunch node.
+ std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
+ gtl::FlatSet<Node*> control_inputs;
+ DataTypeVector arg_types(num_args);
+
+ AddControlInputs(*launch, &control_inputs);
+
+ for (int i = 0; i < num_args; ++i) {
+ const Edge* edge = in_edges[i];
+ data_inputs[i] = {edge->src(), edge->src_output()};
+ arg_types[i] = EdgeType(edge);
+ }
+
+ // Appends the variable inputs.
+ for (int i = 0; i < num_variables; ++i) {
+ int pos = variable_start_index + i;
+ const Edge* edge = in_edges[pos];
+ data_inputs[pos] = {edge->src(), edge->src_output()};
+ }
+
+ // Outputs.
+ const int num_outputs = launch->output_types().size();
+ gtl::FlatSet<Node*> control_outputs;
+ std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
+ DataTypeVector output_types(num_outputs);
+
+ for (const Edge* le : launch->out_edges()) {
+ if (le->IsControlEdge()) {
+ control_outputs.insert(le->dst());
+ } else {
+ TF_RET_CHECK(le->src_output() < num_outputs);
+ Node* output_node = le->dst();
+
+ TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
+ << le->DebugString();
+ nodes_to_remove.push_back(output_node);
+
+ for (const Edge* oe : output_node->out_edges()) {
+ TF_RET_CHECK(!oe->IsControlEdge());
+ data_outputs[le->src_output()].push_back(
+ {oe->dst(), oe->dst_input()});
+ }
+ output_types[le->src_output()] = output_node->input_type(0);
+
+ AddControlOutputs(*output_node, &control_outputs);
+ }
+ }
+
+ NodeDef def;
+ def.set_name(launch->name());
+
+ // Target the XLA CPU/GPU backends.
+ VLOG(2) << "Replacing with XlaLaunch";
+ def.set_op("XlaLaunch");
+ AddNodeAttr("Tconstants", DataTypeVector{}, &def);
+ AddNodeAttr("Targs", arg_types, &def);
+ AddNodeAttr("Nresources", num_variables, &def);
+ AddNodeAttr("Tresults", output_types, &def);
+ NameAttrList function;
+ function.set_name(launch->type_string());
+ AddNodeAttr("function", function, &def);
+
+ for (Node* node : nodes_to_remove) {
+ VLOG(2) << "Deleting node " << node->DebugString();
+ // Ensure that we do not attempt to add control edges to nodes that are
+ // deleted.
+ control_inputs.erase(node);
+ control_outputs.erase(node);
+ graph->RemoveNode(node);
+ }
+
+ Status status;
+ Node* xla_launch = graph->AddNode(def, &status);
+ if (!status.ok()) {
+ return status;
+ }
+ for (int i = 0; i < data_inputs.size(); ++i) {
+ graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
+ i);
+ }
+ for (Node* n : control_inputs) {
+ graph->AddControlEdge(n, xla_launch);
+ }
+ for (int i = 0; i < data_outputs.size(); ++i) {
+ for (const auto& successor : data_outputs[i]) {
+ graph->AddEdge(xla_launch, i, successor.first, successor.second);
+ }
+ }
+ for (Node* n : control_outputs) {
+ graph->AddControlEdge(xla_launch, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status EncapsulateXlaComputationsPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "EncapsulateXlaComputations(): "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
+ VLOG(1) << "EncapsulateXlaComputations() half-way: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
+ **options.graph, options.flib_def);
+
+ TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
+ VLOG(1) << "EncapsulateXlaComputations() finished: "
+ << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
+ **options.graph, options.flib_def);
+ return Status::OK();
+}
+
+} // namespace tensorflow
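
The argument reordering performed by RewriteSubgraph above is a lexicographic sort on the (is_resource, name) pair: non-resource arguments first, resources last, names as a deterministic tiebreaker. A standalone sketch of just that ordering rule, with hypothetical argument names:

    #include <algorithm>
    #include <string>
    #include <tuple>
    #include <vector>

    struct Arg {
      bool is_resource;
      std::string name;
    };

    int main() {
      // Hypothetical arguments: two resource variables, two plain tensors.
      std::vector<Arg> args = {
          {true, "u"}, {false, "b"}, {true, "v"}, {false, "a"}};
      std::sort(args.begin(), args.end(), [](const Arg& x, const Arg& y) {
        // Non-resources sort before resources; names break ties.
        return std::tie(x.is_resource, x.name) <
               std::tie(y.is_resource, y.name);
      });
      // Resulting order: a, b, u, v -- plain arguments first, then resource
      // variables, matching the XlaLaunch operand convention.
      return 0;
    }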
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
new file mode 100644
index 0000000000..c8bb4dc114
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Rewrites computations generated by the xla.compile() Python code into
+// XlaLaunch nodes.
+//
+// xla.compile() does two main things:
+// a) marks operators that make up an XLA computation with the attribute
+// _xla_compile_id=XYZ, where XYZ is a unique key.
+// b) adds XlaClusterOutput nodes to represent outputs of the computation.
+// These nodes are not marked with the _xla_compile_id attribute.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Encapsulates nodes marked with the _xla_compile_id attribute into
+// XlaLaunch operators.
+class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
+ public:
+ static const char* const kXlaClusterAttr; // _xla_compile_id
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ // The following methods are public only for unit tests.
+
+ // This pass has two stages:
+ // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
+ // marked with the same _xla_compile_id attribute into functions. These
+ // functions contain the computations to be passed to XlaLaunch. During
+ // encapsulation, we sort the arguments into the order expected by
+ // XlaLaunch.
+ static Status Encapsulate(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // b) we rewrite the function calls generated in phase (a) into XlaLaunch
+ // operators. We also convert the XlaClusterOutput output nodes of the
+ // function call into the outputs of the XlaLaunch operator.
+ static Status BuildXlaLaunchOps(Graph* graph);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
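
A sketch of how the two public stages compose, following the unit tests in this change (MakeMarkedGraph is the hypothetical builder sketched near the top of this page):

    #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    Status EncapsulateAndBuildLaunchOps(std::unique_ptr<Graph>* graph,
                                        FunctionLibraryDefinition* flib_def) {
      // Stage (a): fold each _xla_compile_id cluster into a function and
      // replace it with a call node, sorting arguments into XlaLaunch order.
      TF_RETURN_IF_ERROR(
          EncapsulateXlaComputationsPass::Encapsulate(graph, flib_def));
      // Stage (b): rewrite each call node and its XlaClusterOutput consumers
      // into a single XlaLaunch op.
      TF_RETURN_IF_ERROR(
          EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph->get()));
      return Status::OK();
    }

    }  // namespace tensorflow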
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
new file mode 100644
index 0000000000..f643fb0cfe
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -0,0 +1,346 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/test_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+static std::unique_ptr<Graph> MakeOuterGraph(
+ const FunctionLibraryDefinition& flib_def, const string& function) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NodeDef def;
+ TF_CHECK_OK(
+ NodeDefBuilder("launch0", function, &flib_def)
+ .Input(a.node()->name(), 0, DT_INT32)
+ .Input(b.node()->name(), 0, DT_FLOAT)
+ .Input(c.node()->name(), 0, DT_INT32)
+ .Input(d.node()->name(), 0, DT_FLOAT)
+ .Input(u.node()->name(), 0, DT_RESOURCE)
+ .Input(v.node()->name(), 0, DT_RESOURCE)
+ .Input(w.node()->name(), 0, DT_RESOURCE)
+ .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
+ .Attr("_variable_start_index", 4)
+ .Finalize(&def));
+
+ Status status;
+ Node* launch = scope.graph()->AddNode(def, &status);
+ TF_CHECK_OK(status);
+ TF_CHECK_OK(scope.DoShapeInference(launch));
+ scope.graph()->AddEdge(a.node(), 0, launch, 0);
+ scope.graph()->AddEdge(b.node(), 0, launch, 1);
+ scope.graph()->AddEdge(c.node(), 0, launch, 2);
+ scope.graph()->AddEdge(d.node(), 0, launch, 3);
+ scope.graph()->AddEdge(u.node(), 0, launch, 4);
+ scope.graph()->AddEdge(v.node(), 0, launch, 5);
+ scope.graph()->AddEdge(w.node(), 0, launch, 6);
+
+ auto out0 =
+ ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
+ auto out1 =
+ ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
+ auto out2 =
+ ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
+ auto out3 =
+ ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+// Makes the body graph of an encapsulated function, for use in tests.
+static std::unique_ptr<Graph> MakeBodyGraph() {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
+ auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
+ auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
+
+ auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
+ auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
+ auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, arg3);
+ add_attrs(g.node());
+
+ auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
+ b_identity, 0);
+ auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
+ auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
+ auto out3 =
+ ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ return graph;
+}
+
+TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
+  // Tests that control edge insertion order doesn't affect the cache key
+  // (cluster name) generated by the encapsulate pass.
+ auto get_serialized_graph = [](bool control_input_reversed,
+ bool operand_reversed) -> string {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
+ auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
+
+ ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
+ : ops::Add(scope.WithOpName("E"), a1, a0);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
+ "launch0");
+ };
+ add_attrs(e.node());
+
+ TF_CHECK_OK(scope.ToGraph(graph.get()));
+ auto get_node_in_graph = [&graph](Node* node) {
+ return graph->FindNodeId(node->id());
+ };
+      // Insert the control edges in different orders. The order should not
+      // affect the encapsulated or serialized graph.
+ if (!control_input_reversed) {
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ } else {
+ graph->AddControlEdge(get_node_in_graph(a1.node()),
+ get_node_in_graph(e.node()), true);
+ graph->AddControlEdge(get_node_in_graph(a0.node()),
+ get_node_in_graph(e.node()), true);
+ }
+ }
+ TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+ GraphDef gdef;
+ graph->ToGraphDef(&gdef);
+    // Before serialization, sort the control inputs to remove
+    // nondeterminism.
+ SortControlInputs(&gdef);
+ string serialized;
+ SerializeToStringDeterministic(gdef, &serialized);
+ return serialized;
+ };
+
+  // Changing the order of control inputs shouldn't affect the generated graph.
+ EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
+ /*operand_reversed=*/false),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+
+  // Changing the order of data inputs should affect the generated graph.
+ EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/true),
+ get_serialized_graph(/*control_input_reversed=*/false,
+ /*operand_reversed=*/false));
+}
+
+TEST(EncapsulateXlaComputations, Encapsulate) {
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ auto add_attrs = [](Node* node) {
+ node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ };
+
+ auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
+ add_attrs(b_identity.node());
+
+ auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
+ add_attrs(read_u.node());
+ auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
+ add_attrs(read_v.node());
+ auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
+ add_attrs(read_w.node());
+
+ auto e = ops::Add(scope.WithOpName("E"), a, c);
+ add_attrs(e.node());
+ auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
+ add_attrs(f.node());
+ auto g = ops::Add(scope.WithOpName("G"), f, d);
+ add_attrs(g.node());
+
+ auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
+ auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
+ auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
+ auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
+
+ auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
+ auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
+ auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
+ auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
+ auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
+ auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+
+ std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
+ CopyGraph(*graph, graph_copy.get());
+
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
+
+ std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
+ string function = index.at("launch0")->type_string();
+
+ // Tests the outer graph is as expected.
+ {
+ std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
+ GraphDef expected_def;
+ outer->ToGraphDef(&expected_def);
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
+ }
+
+ // Tests the encapsulated body graph is as expected.
+ {
+ std::unique_ptr<Graph> body = MakeBodyGraph();
+ GraphDef expected_body_def;
+ body->ToGraphDef(&expected_body_def);
+
+ InstantiationResultForTest result;
+ TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
+
+ EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
+ DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
+ result.arg_types);
+ EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
+ result.ret_types);
+ TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
+ }
+
+  // Encapsulates the same computation again and verifies that we reuse the
+  // same function; encapsulation should be deterministic to avoid recompilation.
+ TF_ASSERT_OK(
+ EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
+ std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
+ string function_copy = index_copy.at("launch0")->type_string();
+ EXPECT_EQ(function, function_copy);
+}
+
+TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
+ std::unique_ptr<Graph> body_graph = MakeBodyGraph();
+ FunctionDefLibrary flib;
+ TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
+
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+ std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
+ TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
+
+ Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
+ TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
+
+ auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
+ auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
+ auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
+ auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
+ auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
+ auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
+
+ NameAttrList function;
+ function.set_name("launch0");
+ auto launch = ops::XlaLaunch(
+ scope.WithOpName("launch0"), std::initializer_list<Input>{},
+ std::initializer_list<Input>{a, b, c, d},
+ std::initializer_list<Input>{u, v, w},
+ DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
+
+ auto consumer0_a =
+ ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
+ auto consumer0_b =
+ ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
+ auto consumer0_c =
+ ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
+ auto consumer1 =
+ ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
+ auto consumer2 =
+ ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
+ auto consumer3 =
+ ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
+
+ GraphDef expected_def;
+ TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
+
+ GraphDef actual_def;
+ graph->ToGraphDef(&actual_def);
+ TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index c37b6112cc..315fcb2fa7 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -15,12 +15,19 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
+// EncapsulateXlaComputationsPass rewrites computations generated by the
+// xla.compile() Python code into XlaLaunch nodes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
+ EncapsulateXlaComputationsPass);
+
+// The following POST_REWRITE passes support auto-clustering to enable XLA.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index f2473d98ff..1a29c3caab 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,10 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -32,4 +36,19 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
+REGISTER_OP("XlaClusterOutput")
+ .Input("input: T")
+ // Note: when replication is supported, this op will have N outputs.
+ .Output("outputs: T")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(0));
+ }
+ return Status::OK();
+ })
+ .Doc(
+ "Operator that connects the output of an XLA computation to other "
+ "consumer graph nodes.");
+
} // namespace tensorflow