author     Yanan Cao <ycao@google.com>                       2018-09-11 09:33:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-11 09:38:56 -0700
commit     ac60b46e2c5962fd8099a4406c1788d826ad3c0d (patch)
tree       9ddf45e02b1d7cd0828bc1a216bc4a58af7a562d /tensorflow/compiler/jit
parent     847b38406a28546991b62193278ee87910cd3d74 (diff)
Automated rollback of commit 45965cfd8b54fb113275ffdaced5366e28aa3553
PiperOrigin-RevId: 212465918
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/BUILD                                        6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc               17
-rw-r--r--  tensorflow/compiler/jit/encapsulate_subgraphs_pass.h                 6
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc       360
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass.h         61
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc  346
-rw-r--r--  tensorflow/compiler/jit/jit_compilation_pass_registration.cc         7
-rw-r--r--  tensorflow/compiler/jit/ops/xla_ops.cc                               19
8 files changed, 0 insertions, 822 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 352f63bc98..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -362,7 +362,6 @@ cc_library(
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
- "encapsulate_xla_computations_pass.cc",
"mark_for_compilation_pass.cc",
"mark_for_compilation_pass_test_helper.cc",
"partially_decluster_pass.cc",
@@ -371,7 +370,6 @@ cc_library(
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
- "encapsulate_xla_computations_pass.h",
"mark_for_compilation_pass.h",
"mark_for_compilation_pass_test_helper.h",
"partially_decluster_pass.h",
@@ -398,7 +396,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
- "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@@ -477,7 +474,6 @@ tf_cc_test(
size = "small",
srcs = [
"encapsulate_subgraphs_pass_test.cc",
- "encapsulate_xla_computations_pass_test.cc",
"mark_for_compilation_pass_test.cc",
"partially_decluster_pass_test.cc",
],
@@ -493,9 +489,7 @@ tf_cc_test(
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/compiler/jit/kernels:xla_launch_op",
- "//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e0632ff7e4..ae7a22f451 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -59,22 +58,6 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
const char* const kXlaHostTransferSequencerAttr =
"_xla_host_transfer_sequencer";
-void SortControlInputs(GraphDef* gdef) {
- int64 num_nodes = gdef->node_size();
- for (int64 i = 0; i < num_nodes; ++i) {
- NodeDef* node = gdef->mutable_node(i);
- // Stable sort control inputs and leave the order of data inputs unchanged.
- std::stable_sort(node->mutable_input()->begin(),
- node->mutable_input()->end(),
- [](const string& a, const string& b) {
- bool a_is_control = absl::StartsWith(a, "^");
- bool b_is_control = absl::StartsWith(b, "^");
- return (!a_is_control && b_is_control) ||
- (a_is_control && b_is_control && a < b);
- });
- }
-}
-
namespace {
bool AreAllParentsGuaranteedConst(
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 90354a801a..926589546f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,12 +102,6 @@ extern const char* const kXlaNumConstantArgsAttr;
// Name of the attribute containing the number of resource variable arguments.
extern const char* const kXlaNumResourceArgsAttr;
-// Sorts each node's control inputs by their names. This guarantees that for two
-// structurally equivalent GraphDefs, we get the same traversal ordering on
-// node's control input fields.
-// TODO(hpucha): Move the utilities to a more appropriate place.
-void SortControlInputs(GraphDef* gdef);
-
class EncapsulateSubgraphsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
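For reference, the SortControlInputs declaration removed above was paired with deterministic serialization and fingerprinting inside the encapsulation pass (see the deleted RewriteSubgraph further down in this diff). A minimal sketch of that pattern, assuming the removed declaration were still present, could look like the following; FingerprintGraph is a hypothetical helper name used only for illustration, and every call in it appears in the removed code below.

#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"  // SortControlInputs
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/fingerprint.h"

namespace tensorflow {

// Sketch: produce a stable fingerprint for structurally equivalent graphs by
// sorting control inputs before a best-effort deterministic serialization.
// FingerprintGraph is a hypothetical name, not part of the codebase.
uint64 FingerprintGraph(const Graph& graph) {
  GraphDef gdef;
  graph.ToGraphDef(&gdef);
  SortControlInputs(&gdef);
  string serialized;
  if (!SerializeToStringDeterministic(gdef, &serialized)) {
    return 0;  // Serialization failure; callers would treat this as a cache miss.
  }
  return Fingerprint64(serialized);
}

}  // namespace tensorflow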
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
deleted file mode 100644
index 97ef8cd3cb..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ /dev/null
@@ -1,360 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
-
-namespace tensorflow {
-
-const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
- "_xla_compile_id";
-
-namespace {
-
-const char* const kXlaClusterOutput = "XlaClusterOutput";
-
-// Checks if a graph node is marked to be a guaranteed constant.
-bool is_guaranteed_constant(const Node& n) {
- bool guaranteed_constant = false;
- if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
- .ok()) {
- return false;
- }
- return guaranteed_constant;
-}
-
-// Finds the `index` of an _Arg or _Retval node.
-Status GetIndexAttr(const Node& n, int num_args, int* index) {
- TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
- if (*index < 0 || *index >= num_args) {
- return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
- *index);
- }
- return Status::OK();
-}
-
-// Returns the data type of the destination of an edge.
-DataType EdgeType(const Edge* edge) {
- return edge->dst()->input_type(edge->dst_input());
-}
-
-// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
- for (const Edge* edge : node.in_edges()) {
- if (edge->IsControlEdge()) {
- deps->insert(edge->src());
- }
- }
-}
-
-// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
- for (const Edge* edge : node.out_edges()) {
- if (edge->IsControlEdge()) {
- deps->insert(edge->dst());
- }
- }
-}
-
-// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
-// the arguments into the order expected by XlaLaunch computations:
-// 1) arguments
-// 2) resource variable arguments
-// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
-// of the arguments.
-//
-// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
-Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
- std::unique_ptr<Graph>* graph_ptr,
- std::vector<int>* input_permutation,
- std::vector<int>* output_permutation,
- NodeDef* call_def) {
- Graph* graph = graph_ptr->get();
- const int num_args = input_permutation->size();
- const int num_retvals = output_permutation->size();
-
- std::vector<Node*> args;
- std::vector<Node*> retvals;
- args.reserve(num_args);
- retvals.reserve(num_retvals);
- for (Node* n : graph->nodes()) {
- if (n->type_string() == "_Arg") {
- // Check if this is a guaranteed constant.
- if (is_guaranteed_constant(*n)) {
- return errors::InvalidArgument(
- "Guaranteed constants are not supported (", n->name(), ")");
- }
- args.push_back(n);
- } else if (n->type_string() == "_Retval") {
- retvals.push_back(n);
- }
- }
-
- if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
- return errors::InvalidArgument("Missing or non-consecutive arguments");
- }
-
- // Reorders the arguments.
- std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
- // Non-resources appear before resources
- bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
- bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
- // Uses the name as a tiebreaker so the output is deterministic.
- StringPiece a_name(a->name());
- StringPiece b_name(b->name());
- return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
- });
-
- // Sorts the retvals by name so the order is deterministic.
- std::sort(retvals.begin(), retvals.end(),
- [](Node* a, Node* b) { return a->name() < b->name(); });
-
- // Computes the permutation to produce the correct argument order, and update
- // the argument indices.
- int variable_start_index = num_args;
- for (int i = 0; i < num_args; ++i) {
- int index;
- TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
- if (args[i]->output_type(0) == DT_RESOURCE &&
- variable_start_index == num_args) {
- variable_start_index = i;
- }
- (*input_permutation)[index] = i;
- args[i]->AddAttr("index", i);
- }
- VLOG(4) << "variable_start_index: " << variable_start_index;
-
- // Computes the permutation to produce the correct retval order, and update
- // the argument indices.
- for (int i = 0; i < num_retvals; ++i) {
- int index;
- TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
- (*output_permutation)[index] = i;
- retvals[i]->AddAttr("index", i);
- }
-
- AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
- call_def);
- AddNodeAttr("_variable_start_index", variable_start_index, call_def);
-
- // Uniquify the function name.
- GraphDef gdef;
- graph->ToGraphDef(&gdef);
-
- // Before serialization, sort each node's control inputs to achieve
- // determinism. Sorting control inputs could help (but not necessarily) create
- // a deterministic serialization and fingerprint. Other sources of
- // nondeterminism include unstable node ordering.
- SortControlInputs(&gdef);
- // Fingerprint the function.
- // Nondeterminism in serialization would not lead to incorrect results, but
- // may cause spurious cache misses. DeterministicSerialization is a
- // best-effort deterministic serialization.
- string serialized;
- TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
- uint64 fingerprint = Fingerprint64(serialized);
- LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
- call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
- return Status::OK();
-}
-
-} // namespace
-
-/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
- std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
- // Check for undeclared outputs before Encapsulation, so we can give a better
- // error message.
- // TODO(phawkins): merge this with the encapsulation code to avoid the extra
- // O(n) pass over the edges.
- for (const Edge* e : (*graph)->edges()) {
- if (!e->IsControlEdge() &&
- e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
- e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
- e->dst()->type_string() != kXlaClusterOutput) {
- return errors::InvalidArgument(
- "Undeclared output of XLA computation. A common cause of this error "
- "is variable initializers that depend on the XLA computation. Edge: ",
- e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
- e->dst_input());
- }
- }
-
- auto output = absl::make_unique<Graph>((*graph)->op_registry());
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- EncapsulateSubgraphsInFunctions(
- kXlaClusterAttr, "", **graph, RewriteSubgraph,
- /*reuse_existing_functions=*/true, &output, flib_def),
- "EncapsulateXlaComputationsPass failed");
- graph->swap(output);
- return Status::OK();
-}
-
-/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
- Graph* graph) {
- // Finds all of the XlaLaunch function calls, to avoid mutating the graph
- // while iterating.
- std::vector<Node*> launch_nodes;
- for (Node* n : graph->nodes()) {
- string name;
- if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
- launch_nodes.push_back(n);
- }
- }
-
- // Replaces each launch function call together with its neighboring
- // XlaClusterOutput nodes with a XlaLaunch node.
- for (Node* launch : launch_nodes) {
- int variable_start_index;
- TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
- &variable_start_index));
-
- std::vector<const Edge*> in_edges;
- TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
-
- const int num_inputs = in_edges.size();
- const int num_variables = num_inputs - variable_start_index;
- const int num_args = variable_start_index;
-
- VLOG(4) << "Launch node '" << launch->name() << "'"
- << " input edges: " << in_edges.size() << " num_args: " << num_args
- << " num_variables: " << num_variables;
-
- std::vector<Node*> nodes_to_remove = {launch};
-
- // Data and control inputs to the new XlaLaunch node.
- std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
- gtl::FlatSet<Node*> control_inputs;
- DataTypeVector arg_types(num_args);
-
- AddControlInputs(*launch, &control_inputs);
-
- for (int i = 0; i < num_args; ++i) {
- const Edge* edge = in_edges[i];
- data_inputs[i] = {edge->src(), edge->src_output()};
- arg_types[i] = EdgeType(edge);
- }
-
- // Appends the variable inputs.
- for (int i = 0; i < num_variables; ++i) {
- int pos = variable_start_index + i;
- const Edge* edge = in_edges[pos];
- data_inputs[pos] = {edge->src(), edge->src_output()};
- }
-
- // Outputs.
- const int num_outputs = launch->output_types().size();
- gtl::FlatSet<Node*> control_outputs;
- std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
- DataTypeVector output_types(num_outputs);
-
- for (const Edge* le : launch->out_edges()) {
- if (le->IsControlEdge()) {
- control_outputs.insert(le->dst());
- } else {
- TF_RET_CHECK(le->src_output() < num_outputs);
- Node* output_node = le->dst();
-
- TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
- << le->DebugString();
- nodes_to_remove.push_back(output_node);
-
- for (const Edge* oe : output_node->out_edges()) {
- TF_RET_CHECK(!oe->IsControlEdge());
- data_outputs[le->src_output()].push_back(
- {oe->dst(), oe->dst_input()});
- }
- output_types[le->src_output()] = output_node->input_type(0);
-
- AddControlOutputs(*output_node, &control_outputs);
- }
- }
-
- NodeDef def;
- def.set_name(launch->name());
-
- // Target the XLA CPU/GPU backends.
- VLOG(2) << "Replacing with XlaLaunch";
- def.set_op("XlaLaunch");
- AddNodeAttr("Tconstants", DataTypeVector{}, &def);
- AddNodeAttr("Targs", arg_types, &def);
- AddNodeAttr("Nresources", num_variables, &def);
- AddNodeAttr("Tresults", output_types, &def);
- NameAttrList function;
- function.set_name(launch->type_string());
- AddNodeAttr("function", function, &def);
-
- for (Node* node : nodes_to_remove) {
- VLOG(2) << "Deleting node " << node->DebugString();
- // Ensure that we do not attempt to add control edges to nodes that are
- // deleted.
- control_inputs.erase(node);
- control_outputs.erase(node);
- graph->RemoveNode(node);
- }
-
- Status status;
- Node* xla_launch = graph->AddNode(def, &status);
- if (!status.ok()) {
- return status;
- }
- for (int i = 0; i < data_inputs.size(); ++i) {
- graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
- i);
- }
- for (Node* n : control_inputs) {
- graph->AddControlEdge(n, xla_launch);
- }
- for (int i = 0; i < data_outputs.size(); ++i) {
- for (const auto& successor : data_outputs[i]) {
- graph->AddEdge(xla_launch, i, successor.first, successor.second);
- }
- }
- for (Node* n : control_outputs) {
- graph->AddControlEdge(xla_launch, n);
- }
- }
- return Status::OK();
-}
-
-Status EncapsulateXlaComputationsPass::Run(
- const GraphOptimizationPassOptions& options) {
- VLOG(1) << "EncapsulateXlaComputations(): "
- << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
- **options.graph, options.flib_def);
-
- TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
- VLOG(1) << "EncapsulateXlaComputations() half-way: "
- << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
- **options.graph, options.flib_def);
-
- TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
- VLOG(1) << "EncapsulateXlaComputations() finished: "
- << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
- **options.graph, options.flib_def);
- return Status::OK();
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
deleted file mode 100644
index c8bb4dc114..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Rewrites computations generated by the xla.compile() Python code into
-// XlaLaunch nodes.
-//
-// xla.compile() does two main things:
-// a) marks operators that make up a XLA computation with the attribute
-// _xla_compile_id=XYZ, where XYZ is a unique key.
-// b) adds XlaClusterOutput nodes to represent outputs of the computation.
-// These nodes are not marked with the _xla_compile_id attribute.
-
-#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
-
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/platform/env.h"
-
-namespace tensorflow {
-
-// Encapsulates nodes marked with the _xla_compile_id attribute into
-// XlaLaunch operators.
-class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
- public:
- static const char* const kXlaClusterAttr; // _xla_compile_id
-
- Status Run(const GraphOptimizationPassOptions& options) override;
-
- // The following methods are public only for unit tests.
-
- // This pass has two stages:
- // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
- // marked with the same _xla_compile_id attribute into functions. These
- // functions contain the computations to be passed to XlaLaunch. During
- // encapsulation, we sort the arguments into the order expected by
- // XlaLaunch.
- static Status Encapsulate(std::unique_ptr<Graph>* graph,
- FunctionLibraryDefinition* flib_def);
-
- // b) we rewrite the function calls generated in phase (a) into XlaLaunch
- // operators. We also convert the XlaClusterOutput output nodes of the
- // function call into the outputs of the XlaLaunch operator.
- static Status BuildXlaLaunchOps(Graph* graph);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
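The class comment removed above describes a two-stage pipeline (Encapsulate, then BuildXlaLaunchOps), which the deleted Run() method and the BuildXlaLaunchOp test later in this diff exercise back to back. A condensed, hedged sketch of that flow, assuming the removed header were still in the tree, might read as follows; EncapsulateAndBuildLaunchOps is a hypothetical helper name used only for illustration.

#include <memory>

#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Sketch: stage (a) groups nodes tagged with _xla_compile_id into function
// calls; stage (b) rewrites those calls and their XlaClusterOutput consumers
// into XlaLaunch operators.
Status EncapsulateAndBuildLaunchOps(std::unique_ptr<Graph>* graph,
                                    FunctionLibraryDefinition* flib_def) {
  TF_RETURN_IF_ERROR(
      EncapsulateXlaComputationsPass::Encapsulate(graph, flib_def));
  return EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph->get());
}

}  // namespace tensorflow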
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
deleted file mode 100644
index f643fb0cfe..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ /dev/null
@@ -1,346 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
-
-#include "tensorflow/cc/ops/function_ops.h"
-#include "tensorflow/cc/ops/resource_variable_ops.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
-#include "tensorflow/compiler/tf2xla/test_util.h"
-#include "tensorflow/core/framework/graph_to_functiondef.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/util/equal_graph_def.h"
-#include "tensorflow/core/util/ptr_util.h"
-
-namespace tensorflow {
-
-static std::unique_ptr<Graph> MakeOuterGraph(
- const FunctionLibraryDefinition& flib_def, const string& function) {
- Scope scope = Scope::NewRootScope().ExitOnError();
- TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
-
- auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
- auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
- auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
- auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
- auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
- auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
- auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
- NodeDef def;
- TF_CHECK_OK(
- NodeDefBuilder("launch0", function, &flib_def)
- .Input(a.node()->name(), 0, DT_INT32)
- .Input(b.node()->name(), 0, DT_FLOAT)
- .Input(c.node()->name(), 0, DT_INT32)
- .Input(d.node()->name(), 0, DT_FLOAT)
- .Input(u.node()->name(), 0, DT_RESOURCE)
- .Input(v.node()->name(), 0, DT_RESOURCE)
- .Input(w.node()->name(), 0, DT_RESOURCE)
- .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
- .Attr("_variable_start_index", 4)
- .Finalize(&def));
-
- Status status;
- Node* launch = scope.graph()->AddNode(def, &status);
- TF_CHECK_OK(status);
- TF_CHECK_OK(scope.DoShapeInference(launch));
- scope.graph()->AddEdge(a.node(), 0, launch, 0);
- scope.graph()->AddEdge(b.node(), 0, launch, 1);
- scope.graph()->AddEdge(c.node(), 0, launch, 2);
- scope.graph()->AddEdge(d.node(), 0, launch, 3);
- scope.graph()->AddEdge(u.node(), 0, launch, 4);
- scope.graph()->AddEdge(v.node(), 0, launch, 5);
- scope.graph()->AddEdge(w.node(), 0, launch, 6);
-
- auto out0 =
- ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
- auto out1 =
- ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
- auto out2 =
- ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
- auto out3 =
- ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
-
- auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
- auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
- auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
- auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
- auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
- auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
-
- std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- TF_CHECK_OK(scope.ToGraph(graph.get()));
- return graph;
-}
-
-// Makes an encapsulate body graph for use in tests.
-static std::unique_ptr<Graph> MakeBodyGraph() {
- Scope scope = Scope::NewRootScope().ExitOnError();
-
- auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
- auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
- auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
- auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
-
- auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
- auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
- auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
-
- auto add_attrs = [](Node* node) {
- node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
- };
-
- auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
-
- auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
- add_attrs(read_u.node());
- auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
- add_attrs(read_v.node());
- auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
- add_attrs(read_w.node());
-
- auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
- add_attrs(e.node());
- auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
- add_attrs(f.node());
- auto g = ops::Add(scope.WithOpName("G"), f, arg3);
- add_attrs(g.node());
-
- auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
- b_identity, 0);
- auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
- auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
- auto out3 =
- ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
-
- std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- TF_CHECK_OK(scope.ToGraph(graph.get()));
- return graph;
-}
-
-TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
- // Test that control edge insertion order doesn't affect the cache key
- // (cluster name) generated by TPU encapsulate pass.
- auto get_serialized_graph = [](bool control_input_reversed,
- bool operand_reversed) -> string {
- FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
- std::unique_ptr<Graph> graph(new Graph(&flib_def));
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
- auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
-
- ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
- : ops::Add(scope.WithOpName("E"), a1, a0);
-
- auto add_attrs = [](Node* node) {
- node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
- "launch0");
- };
- add_attrs(e.node());
-
- TF_CHECK_OK(scope.ToGraph(graph.get()));
- auto get_node_in_graph = [&graph](Node* node) {
- return graph->FindNodeId(node->id());
- };
- // Insert control edge in different order. The order should not affect
- // the encapsulated or serialized graph.
- if (!control_input_reversed) {
- graph->AddControlEdge(get_node_in_graph(a0.node()),
- get_node_in_graph(e.node()), true);
- graph->AddControlEdge(get_node_in_graph(a1.node()),
- get_node_in_graph(e.node()), true);
- } else {
- graph->AddControlEdge(get_node_in_graph(a1.node()),
- get_node_in_graph(e.node()), true);
- graph->AddControlEdge(get_node_in_graph(a0.node()),
- get_node_in_graph(e.node()), true);
- }
- }
- TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
- GraphDef gdef;
- graph->ToGraphDef(&gdef);
- // Before serialization, sort control inputs first to remove
- // nondeterminism.
- SortControlInputs(&gdef);
- string serialized;
- SerializeToStringDeterministic(gdef, &serialized);
- return serialized;
- };
-
- // Changing the order of control input shouldn't affect the graph generated.
- EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
- /*operand_reversed=*/false),
- get_serialized_graph(/*control_input_reversed=*/false,
- /*operand_reversed=*/false));
-
- // Changing the order of data input should affect the graph generated.
- EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
- /*operand_reversed=*/true),
- get_serialized_graph(/*control_input_reversed=*/false,
- /*operand_reversed=*/false));
-}
-
-TEST(EncapsulateXlaComputations, Encapsulate) {
- FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
- std::unique_ptr<Graph> graph(new Graph(&flib_def));
- {
- Scope scope = Scope::NewRootScope().ExitOnError();
- auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
- auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
- auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
- auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
- auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
- auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
- auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
- auto add_attrs = [](Node* node) {
- node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
- };
-
- auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
- add_attrs(b_identity.node());
-
- auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
- add_attrs(read_u.node());
- auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
- add_attrs(read_v.node());
- auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
- add_attrs(read_w.node());
-
- auto e = ops::Add(scope.WithOpName("E"), a, c);
- add_attrs(e.node());
- auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
- add_attrs(f.node());
- auto g = ops::Add(scope.WithOpName("G"), f, d);
- add_attrs(g.node());
-
- auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
- auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
- auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
- auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
-
- auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
- auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
- auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
- auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
- auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
- auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
- TF_ASSERT_OK(scope.ToGraph(graph.get()));
- }
-
- std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
- CopyGraph(*graph, graph_copy.get());
-
- TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
-
- std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
- string function = index.at("launch0")->type_string();
-
- // Tests the outer graph is as expected.
- {
- std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
- GraphDef expected_def;
- outer->ToGraphDef(&expected_def);
-
- GraphDef actual_def;
- graph->ToGraphDef(&actual_def);
- TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
- }
-
- // Tests the encapsulated body graph is as expected.
- {
- std::unique_ptr<Graph> body = MakeBodyGraph();
- GraphDef expected_body_def;
- body->ToGraphDef(&expected_body_def);
-
- InstantiationResultForTest result;
- TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
-
- EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
- DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
- result.arg_types);
- EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
- result.ret_types);
- TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
- }
-
- // Encapsulates the same computation again, verifies we reuse the same
- // function. Encapsulation should be deterministic to avoid recompilation.
- TF_ASSERT_OK(
- EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
- std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
- string function_copy = index_copy.at("launch0")->type_string();
- EXPECT_EQ(function, function_copy);
-}
-
-TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
- std::unique_ptr<Graph> body_graph = MakeBodyGraph();
- FunctionDefLibrary flib;
- TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
-
- FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
-
- std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
- TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
-
- Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
- TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
-
- auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
- auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
- auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
- auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
- auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
- auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
- auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
- NameAttrList function;
- function.set_name("launch0");
- auto launch = ops::XlaLaunch(
- scope.WithOpName("launch0"), std::initializer_list<Input>{},
- std::initializer_list<Input>{a, b, c, d},
- std::initializer_list<Input>{u, v, w},
- DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
-
- auto consumer0_a =
- ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
- auto consumer0_b =
- ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
- auto consumer0_c =
- ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
- auto consumer1 =
- ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
- auto consumer2 =
- ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
- auto consumer3 =
- ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
-
- GraphDef expected_def;
- TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
-
- GraphDef actual_def;
- graph->ToGraphDef(&actual_def);
- TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 315fcb2fa7..c37b6112cc 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -15,19 +15,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
-// EncapsulateXlaComputationsPass rewrites computations generated by the
-// xla.compile() Python code into XlaLaunch nodes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
- EncapsulateXlaComputationsPass);
-
-// The following POST_REWRITE passes support auto-clustering to enable XLA.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3caab..f2473d98ff 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -13,14 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
-using shape_inference::InferenceContext;
-
REGISTER_OP("XlaLaunch")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
@@ -36,19 +32,4 @@ REGISTER_OP("XlaLaunch")
.SetIsStateful()
.Doc("XLA Launch Op. For use by the XLA JIT only.");
-REGISTER_OP("XlaClusterOutput")
- .Input("input: T")
- // Note: when replication is supported, this op will have N outputs.
- .Output("outputs: T")
- .Attr("T: type")
- .SetShapeFn([](InferenceContext* c) {
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(0));
- }
- return Status::OK();
- })
- .Doc(
- "Operator that connects the output of an XLA computation to other "
- "consumer graph nodes.");
-
} // namespace tensorflow