| author | Yanan Cao <ycao@google.com> | 2018-09-11 09:33:04 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 09:38:56 -0700 |
| commit | ac60b46e2c5962fd8099a4406c1788d826ad3c0d (patch) | |
| tree | 9ddf45e02b1d7cd0828bc1a216bc4a58af7a562d /tensorflow/compiler/jit | |
| parent | 847b38406a28546991b62193278ee87910cd3d74 (diff) | |
Automated rollback of commit 45965cfd8b54fb113275ffdaced5366e28aa3553
PiperOrigin-RevId: 212465918
Diffstat (limited to 'tensorflow/compiler/jit')
8 files changed, 0 insertions, 822 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 352f63bc98..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -362,7 +362,6 @@ cc_library(
         "deadness_analysis.cc",
         "deadness_analysis_internal.h",
         "encapsulate_subgraphs_pass.cc",
-        "encapsulate_xla_computations_pass.cc",
         "mark_for_compilation_pass.cc",
         "mark_for_compilation_pass_test_helper.cc",
         "partially_decluster_pass.cc",
@@ -371,7 +370,6 @@ cc_library(
         "build_xla_launch_ops_pass.h",
         "deadness_analysis.h",
         "encapsulate_subgraphs_pass.h",
-        "encapsulate_xla_computations_pass.h",
         "mark_for_compilation_pass.h",
         "mark_for_compilation_pass_test_helper.h",
         "partially_decluster_pass.h",
@@ -398,7 +396,6 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/kernels:bounds_check",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -477,7 +474,6 @@ tf_cc_test(
     size = "small",
     srcs = [
         "encapsulate_subgraphs_pass_test.cc",
-        "encapsulate_xla_computations_pass_test.cc",
         "mark_for_compilation_pass_test.cc",
         "partially_decluster_pass_test.cc",
     ],
@@ -493,9 +489,7 @@ tf_cc_test(
         "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/cc:sendrecv_ops",
         "//tensorflow/compiler/jit/kernels:xla_launch_op",
-        "//tensorflow/compiler/tf2xla:test_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
-        "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e0632ff7e4..ae7a22f451 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
-#include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -59,22 +58,6 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";
 const char* const kXlaHostTransferSequencerAttr =
     "_xla_host_transfer_sequencer";
 
-void SortControlInputs(GraphDef* gdef) {
-  int64 num_nodes = gdef->node_size();
-  for (int64 i = 0; i < num_nodes; ++i) {
-    NodeDef* node = gdef->mutable_node(i);
-    // Stable sort control inputs and leave the order of data inputs unchanged.
-    std::stable_sort(node->mutable_input()->begin(),
-                     node->mutable_input()->end(),
-                     [](const string& a, const string& b) {
-                       bool a_is_control = absl::StartsWith(a, "^");
-                       bool b_is_control = absl::StartsWith(b, "^");
-                       return (!a_is_control && b_is_control) ||
-                              (a_is_control && b_is_control && a < b);
-                     });
-  }
-}
-
 namespace {
 
 bool AreAllParentsGuaranteedConst(
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 90354a801a..926589546f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -102,12 +102,6 @@ extern const char* const kXlaNumConstantArgsAttr;
 // Name of the attribute containing the number of resource variable arguments.
 extern const char* const kXlaNumResourceArgsAttr;
 
-// Sorts each node's control inputs by their names. This guarantees that for two
-// structurally equivalent GraphDefs, we get the same traversal ordering on
-// node's control input fields.
-// TODO(hpucha): Move the utilities to a more appropriate place.
-void SortControlInputs(GraphDef* gdef);
-
 class EncapsulateSubgraphsPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override;
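Aside on the helper removed above: SortControlInputs leaves data inputs in their original order and orders only the control inputs (the "^"-prefixed names) by name, which is what makes the fingerprinting later in this diff stable. Below is a minimal standalone sketch of that ordering rule, using plain std::string values in place of NodeDef input names; SortControlInputsSketch and the sample inputs are illustrative stand-ins, not TensorFlow APIs.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Sketch of the removed SortControlInputs ordering rule. NodeDef inputs are
// plain strings; control inputs start with '^'. Data inputs keep their
// relative order; control inputs are sorted by name.
void SortControlInputsSketch(std::vector<std::string>* inputs) {
  std::stable_sort(inputs->begin(), inputs->end(),
                   [](const std::string& a, const std::string& b) {
                     bool a_is_control = !a.empty() && a[0] == '^';
                     bool b_is_control = !b.empty() && b[0] == '^';
                     // All data inputs precede all control inputs; only the
                     // control inputs are reordered (by name).
                     return (!a_is_control && b_is_control) ||
                            (a_is_control && b_is_control && a < b);
                   });
}

int main() {
  std::vector<std::string> inputs = {"b:0", "^ctrl_z", "a:0", "^ctrl_a"};
  SortControlInputsSketch(&inputs);
  // Data inputs keep their order; control inputs are now sorted.
  assert((inputs ==
          std::vector<std::string>{"b:0", "a:0", "^ctrl_a", "^ctrl_z"}));
}

std::stable_sort matters here: all data inputs compare as equivalent under the predicate, so stability is what preserves their relative order.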
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
deleted file mode 100644
index 97ef8cd3cb..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ /dev/null
@@ -1,360 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
-
-#include "absl/memory/memory.h"
-#include "absl/strings/str_cat.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
-
-namespace tensorflow {
-
-const char* const EncapsulateXlaComputationsPass::kXlaClusterAttr =
-    "_xla_compile_id";
-
-namespace {
-
-const char* const kXlaClusterOutput = "XlaClusterOutput";
-
-// Checks if a graph node is marked to be a guaranteed constant.
-bool is_guaranteed_constant(const Node& n) {
-  bool guaranteed_constant = false;
-  if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", &guaranteed_constant)
-           .ok()) {
-    return false;
-  }
-  return guaranteed_constant;
-}
-
-// Finds the `index` of an _Arg or _Retval node.
-Status GetIndexAttr(const Node& n, int num_args, int* index) {
-  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index));
-  if (*index < 0 || *index >= num_args) {
-    return errors::InvalidArgument("Invalid ", n.type_string(), " number ",
-                                   *index);
-  }
-  return Status::OK();
-}
-
-// Returns the data type of the destination of an edge.
-DataType EdgeType(const Edge* edge) {
-  return edge->dst()->input_type(edge->dst_input());
-}
-
-// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
-  for (const Edge* edge : node.in_edges()) {
-    if (edge->IsControlEdge()) {
-      deps->insert(edge->src());
-    }
-  }
-}
-
-// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
-  for (const Edge* edge : node.out_edges()) {
-    if (edge->IsControlEdge()) {
-      deps->insert(edge->dst());
-    }
-  }
-}
-
-// Rewrite function to be passed to EncapsulateSubgraphsInFunctions that sorts
-// the arguments into the order expected by XlaLaunch computations:
-// 1) arguments
-// 2) resource variable arguments
-// See the documentation of EncapsulateSubgraphsInFunctions for the meaning
-// of the arguments.
-//
-// TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed.
-Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
-                       std::unique_ptr<Graph>* graph_ptr,
-                       std::vector<int>* input_permutation,
-                       std::vector<int>* output_permutation,
-                       NodeDef* call_def) {
-  Graph* graph = graph_ptr->get();
-  const int num_args = input_permutation->size();
-  const int num_retvals = output_permutation->size();
-
-  std::vector<Node*> args;
-  std::vector<Node*> retvals;
-  args.reserve(num_args);
-  retvals.reserve(num_retvals);
-  for (Node* n : graph->nodes()) {
-    if (n->type_string() == "_Arg") {
-      // Check if this is a guaranteed constant.
-      if (is_guaranteed_constant(*n)) {
-        return errors::InvalidArgument(
-            "Guaranteed constants are not supported (", n->name(), ")");
-      }
-      args.push_back(n);
-    } else if (n->type_string() == "_Retval") {
-      retvals.push_back(n);
-    }
-  }
-
-  if (std::find(args.begin(), args.end(), nullptr) != args.end()) {
-    return errors::InvalidArgument("Missing or non-consecutive arguments");
-  }
-
-  // Reorders the arguments.
-  std::sort(args.begin(), args.end(), [&](Node* a, Node* b) {
-    // Non-resources appear before resources
-    bool a_is_resource = (a->output_type(0) == DT_RESOURCE);
-    bool b_is_resource = (b->output_type(0) == DT_RESOURCE);
-    // Uses the name as a tiebreaker so the output is deterministic.
-    StringPiece a_name(a->name());
-    StringPiece b_name(b->name());
-    return std::tie(a_is_resource, a_name) < std::tie(b_is_resource, b_name);
-  });
-
-  // Sorts the retvals by name so the order is deterministic.
-  std::sort(retvals.begin(), retvals.end(),
-            [](Node* a, Node* b) { return a->name() < b->name(); });
-
-  // Computes the permutation to produce the correct argument order, and update
-  // the argument indices.
-  int variable_start_index = num_args;
-  for (int i = 0; i < num_args; ++i) {
-    int index;
-    TF_RETURN_IF_ERROR(GetIndexAttr(*args[i], num_args, &index));
-    if (args[i]->output_type(0) == DT_RESOURCE &&
-        variable_start_index == num_args) {
-      variable_start_index = i;
-    }
-    (*input_permutation)[index] = i;
-    args[i]->AddAttr("index", i);
-  }
-  VLOG(4) << "variable_start_index: " << variable_start_index;
-
-  // Computes the permutation to produce the correct retval order, and update
-  // the argument indices.
-  for (int i = 0; i < num_retvals; ++i) {
-    int index;
-    TF_RETURN_IF_ERROR(GetIndexAttr(*retvals[i], num_retvals, &index));
-    (*output_permutation)[index] = i;
-    retvals[i]->AddAttr("index", i);
-  }
-
-  AddNodeAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, call_def->name(),
-              call_def);
-  AddNodeAttr("_variable_start_index", variable_start_index, call_def);
-
-  // Uniquify the function name.
-  GraphDef gdef;
-  graph->ToGraphDef(&gdef);
-
-  // Before serialization, sort each node's control inputs to achieve
-  // determinism. Sorting control inputs could help (but not necessarily)
-  // create a deterministic serialization and fingerprint. Other sources of
-  // nondeterminism include unstable node ordering.
-  SortControlInputs(&gdef);
-  // Fingerprint the function.
-  // Nondeterminism in serialization would not lead to incorrect results, but
-  // may cause spurious cache misses. DeterministicSerialization is a
-  // best-effort deterministic serialization.
-  string serialized;
-  TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
-  uint64 fingerprint = Fingerprint64(serialized);
-  LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
-  call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
-  return Status::OK();
-}
-
-}  // namespace
-
-/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate(
-    std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
-  // Check for undeclared outputs before Encapsulation, so we can give a better
-  // error message.
-  // TODO(phawkins): merge this with the encapsulation code to avoid the extra
-  // O(n) pass over the edges.
-  for (const Edge* e : (*graph)->edges()) {
-    if (!e->IsControlEdge() &&
-        e->src()->attrs().Find(kXlaClusterAttr) != nullptr &&
-        e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
-        e->dst()->type_string() != kXlaClusterOutput) {
-      return errors::InvalidArgument(
-          "Undeclared output of XLA computation. A common cause of this error "
-          "is variable initializers that depend on the XLA computation. Edge: ",
-          e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
-          e->dst_input());
-    }
-  }
-
-  auto output = absl::make_unique<Graph>((*graph)->op_registry());
-  TF_RETURN_WITH_CONTEXT_IF_ERROR(
-      EncapsulateSubgraphsInFunctions(
-          kXlaClusterAttr, "", **graph, RewriteSubgraph,
-          /*reuse_existing_functions=*/true, &output, flib_def),
-      "EncapsulateXlaComputationsPass failed");
-  graph->swap(output);
-  return Status::OK();
-}
-
-/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps(
-    Graph* graph) {
-  // Finds all of the XlaLaunch function calls, to avoid mutating the graph
-  // while iterating.
-  std::vector<Node*> launch_nodes;
-  for (Node* n : graph->nodes()) {
-    string name;
-    if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) {
-      launch_nodes.push_back(n);
-    }
-  }
-
-  // Replaces each launch function call together with its neighboring
-  // XlaClusterOutput nodes with a XlaLaunch node.
-  for (Node* launch : launch_nodes) {
-    int variable_start_index;
-    TF_RETURN_IF_ERROR(GetNodeAttr(launch->attrs(), "_variable_start_index",
-                                   &variable_start_index));
-
-    std::vector<const Edge*> in_edges;
-    TF_RETURN_IF_ERROR(launch->input_edges(&in_edges));
-
-    const int num_inputs = in_edges.size();
-    const int num_variables = num_inputs - variable_start_index;
-    const int num_args = variable_start_index;
-
-    VLOG(4) << "Launch node '" << launch->name() << "'"
-            << " input edges: " << in_edges.size() << " num_args: " << num_args
-            << " num_variables: " << num_variables;
-
-    std::vector<Node*> nodes_to_remove = {launch};
-
-    // Data and control inputs to the new XlaLaunch node.
-    std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
-    gtl::FlatSet<Node*> control_inputs;
-    DataTypeVector arg_types(num_args);
-
-    AddControlInputs(*launch, &control_inputs);
-
-    for (int i = 0; i < num_args; ++i) {
-      const Edge* edge = in_edges[i];
-      data_inputs[i] = {edge->src(), edge->src_output()};
-      arg_types[i] = EdgeType(edge);
-    }
-
-    // Appends the variable inputs.
-    for (int i = 0; i < num_variables; ++i) {
-      int pos = variable_start_index + i;
-      const Edge* edge = in_edges[pos];
-      data_inputs[pos] = {edge->src(), edge->src_output()};
-    }
-
-    // Outputs.
-    const int num_outputs = launch->output_types().size();
-    gtl::FlatSet<Node*> control_outputs;
-    std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
-    DataTypeVector output_types(num_outputs);
-
-    for (const Edge* le : launch->out_edges()) {
-      if (le->IsControlEdge()) {
-        control_outputs.insert(le->dst());
-      } else {
-        TF_RET_CHECK(le->src_output() < num_outputs);
-        Node* output_node = le->dst();
-
-        TF_RET_CHECK(output_node->type_string() == kXlaClusterOutput)
-            << le->DebugString();
-        nodes_to_remove.push_back(output_node);
-
-        for (const Edge* oe : output_node->out_edges()) {
-          TF_RET_CHECK(!oe->IsControlEdge());
-          data_outputs[le->src_output()].push_back(
-              {oe->dst(), oe->dst_input()});
-        }
-        output_types[le->src_output()] = output_node->input_type(0);
-
-        AddControlOutputs(*output_node, &control_outputs);
-      }
-    }
-
-    NodeDef def;
-    def.set_name(launch->name());
-
-    // Target the XLA CPU/GPU backends.
-    VLOG(2) << "Replacing with XlaLaunch";
-    def.set_op("XlaLaunch");
-    AddNodeAttr("Tconstants", DataTypeVector{}, &def);
-    AddNodeAttr("Targs", arg_types, &def);
-    AddNodeAttr("Nresources", num_variables, &def);
-    AddNodeAttr("Tresults", output_types, &def);
-    NameAttrList function;
-    function.set_name(launch->type_string());
-    AddNodeAttr("function", function, &def);
-
-    for (Node* node : nodes_to_remove) {
-      VLOG(2) << "Deleting node " << node->DebugString();
-      // Ensure that we do not attempt to add control edges to nodes that are
-      // deleted.
-      control_inputs.erase(node);
-      control_outputs.erase(node);
-      graph->RemoveNode(node);
-    }
-
-    Status status;
-    Node* xla_launch = graph->AddNode(def, &status);
-    if (!status.ok()) {
-      return status;
-    }
-    for (int i = 0; i < data_inputs.size(); ++i) {
-      graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch,
-                     i);
-    }
-    for (Node* n : control_inputs) {
-      graph->AddControlEdge(n, xla_launch);
-    }
-    for (int i = 0; i < data_outputs.size(); ++i) {
-      for (const auto& successor : data_outputs[i]) {
-        graph->AddEdge(xla_launch, i, successor.first, successor.second);
-      }
-    }
-    for (Node* n : control_outputs) {
-      graph->AddControlEdge(xla_launch, n);
-    }
-  }
-  return Status::OK();
-}
-
-Status EncapsulateXlaComputationsPass::Run(
-    const GraphOptimizationPassOptions& options) {
-  VLOG(1) << "EncapsulateXlaComputations(): "
-          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_before",
-                                         **options.graph, options.flib_def);
-
-  TF_RETURN_IF_ERROR(Encapsulate(options.graph, options.flib_def));
-  VLOG(1) << "EncapsulateXlaComputations() half-way: "
-          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_halfway",
-                                         **options.graph, options.flib_def);
-
-  TF_RETURN_IF_ERROR(BuildXlaLaunchOps(options.graph->get()));
-  VLOG(1) << "EncapsulateXlaComputations() finished: "
-          << dump_graph::DumpGraphToFile("encapsulate_xla_computations_after",
-                                         **options.graph, options.flib_def);
-  return Status::OK();
-}
-
-}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
deleted file mode 100644
index c8bb4dc114..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Rewrites computations generated by the xla.compile() Python code into
-// XlaLaunch nodes.
-//
-// xla.compile() does two main things:
-// a) marks operators that make up a XLA computation with the attribute
-//    _xla_compile_id=XYZ, where XYZ is a unique key.
-// b) adds XlaClusterOutput nodes to represent outputs of the computation.
-//    These nodes are not marked with the _xla_compile_id attribute.
-
-#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
-
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/platform/env.h"
-
-namespace tensorflow {
-
-// Encapsulates nodes marked with the _xla_compile_id attribute into
-// XlaLaunch operators.
-class EncapsulateXlaComputationsPass : public GraphOptimizationPass {
- public:
-  static const char* const kXlaClusterAttr;  // _xla_compile_id
-
-  Status Run(const GraphOptimizationPassOptions& options) override;
-
-  // The following methods are public only for unit tests.
-
-  // This pass has two stages:
-  // a) first, we call EncapsulateSubgraphsPass to encapsulate all nodes
-  //    marked with the same _xla_compile_id attribute into functions. These
-  //    functions contain the computations to be passed to XlaLaunch. During
-  //    encapsulation, we sort the arguments into the order expected by
-  //    XlaLaunch.
-  static Status Encapsulate(std::unique_ptr<Graph>* graph,
-                            FunctionLibraryDefinition* flib_def);
-
-  // b) we rewrite the function calls generated in phase (a) into XlaLaunch
-  //    operators. We also convert the XlaClusterOutput output nodes of the
-  //    function call into the outputs of the XlaLaunch operator.
-  static Status BuildXlaLaunchOps(Graph* graph);
-};
-
-}  // namespace tensorflow
-
-#endif  // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_XLA_COMPUTATIONS_PASS_H_
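The deleted header spells out the contract that Encapsulate() enforces with its "undeclared output" check: every data edge leaving a clustered node must end either inside the cluster or at an XlaClusterOutput node. A toy standalone version of that predicate follows; ToyNode and ToyEdge are hypothetical stand-ins for TensorFlow's Node and Edge, not real APIs.

#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  bool in_cluster;         // has the _xla_compile_id attribute
  bool is_cluster_output;  // type_string() == "XlaClusterOutput"
};

struct ToyEdge {
  const ToyNode* src;
  const ToyNode* dst;
  bool is_control;
};

// Returns true (and prints a diagnostic) if a data edge escapes the cluster
// without going through an XlaClusterOutput node.
bool HasUndeclaredOutput(const std::vector<ToyEdge>& edges) {
  for (const ToyEdge& e : edges) {
    if (!e.is_control && e.src->in_cluster && !e.dst->in_cluster &&
        !e.dst->is_cluster_output) {
      std::cerr << "Undeclared output of XLA computation: " << e.src->name
                << " -> " << e.dst->name << "\n";
      return true;
    }
  }
  return false;
}

int main() {
  ToyNode add{"Add", true, false};
  ToyNode out{"Out0", false, true};
  ToyNode init{"VarInit", false, false};
  // Add -> Out0 is a declared output; Add -> VarInit is the failure mode the
  // real error message calls out (an initializer depending on the cluster).
  bool ok = HasUndeclaredOutput({{&add, &out, false}});
  bool bad = HasUndeclaredOutput({{&add, &init, false}});
  std::cout << ok << " " << bad << "\n";  // Prints: 0 1
}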
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
deleted file mode 100644
index f643fb0cfe..0000000000
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ /dev/null
@@ -1,346 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
-
-#include "tensorflow/cc/ops/function_ops.h"
-#include "tensorflow/cc/ops/resource_variable_ops.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
-#include "tensorflow/compiler/tf2xla/test_util.h"
-#include "tensorflow/core/framework/graph_to_functiondef.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/lib/strings/proto_serialization.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/util/equal_graph_def.h"
-#include "tensorflow/core/util/ptr_util.h"
-
-namespace tensorflow {
-
-static std::unique_ptr<Graph> MakeOuterGraph(
-    const FunctionLibraryDefinition& flib_def, const string& function) {
-  Scope scope = Scope::NewRootScope().ExitOnError();
-  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib_def.ToProto()));
-
-  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
-  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
-  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
-  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
-  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
-  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
-  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
-  NodeDef def;
-  TF_CHECK_OK(
-      NodeDefBuilder("launch0", function, &flib_def)
-          .Input(a.node()->name(), 0, DT_INT32)
-          .Input(b.node()->name(), 0, DT_FLOAT)
-          .Input(c.node()->name(), 0, DT_INT32)
-          .Input(d.node()->name(), 0, DT_FLOAT)
-          .Input(u.node()->name(), 0, DT_RESOURCE)
-          .Input(v.node()->name(), 0, DT_RESOURCE)
-          .Input(w.node()->name(), 0, DT_RESOURCE)
-          .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
-          .Attr("_variable_start_index", 4)
-          .Finalize(&def));
-
-  Status status;
-  Node* launch = scope.graph()->AddNode(def, &status);
-  TF_CHECK_OK(status);
-  TF_CHECK_OK(scope.DoShapeInference(launch));
-  scope.graph()->AddEdge(a.node(), 0, launch, 0);
-  scope.graph()->AddEdge(b.node(), 0, launch, 1);
-  scope.graph()->AddEdge(c.node(), 0, launch, 2);
-  scope.graph()->AddEdge(d.node(), 0, launch, 3);
-  scope.graph()->AddEdge(u.node(), 0, launch, 4);
-  scope.graph()->AddEdge(v.node(), 0, launch, 5);
-  scope.graph()->AddEdge(w.node(), 0, launch, 6);
-
-  auto out0 =
-      ops::XlaClusterOutput(scope.WithOpName("Out0"), Output(launch, 0));
-  auto out1 =
-      ops::XlaClusterOutput(scope.WithOpName("Out1"), Output(launch, 1));
-  auto out2 =
-      ops::XlaClusterOutput(scope.WithOpName("Out2"), Output(launch, 2));
-  auto out3 =
-      ops::XlaClusterOutput(scope.WithOpName("Out3"), Output(launch, 3));
-
-  auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
-  auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
-  auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
-  auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
-  auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
-  auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
-
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  TF_CHECK_OK(scope.ToGraph(graph.get()));
-  return graph;
-}
-
-// Makes an encapsulate body graph for use in tests.
-static std::unique_ptr<Graph> MakeBodyGraph() {
-  Scope scope = Scope::NewRootScope().ExitOnError();
-
-  auto arg0 = ops::_Arg(scope.WithOpName("a_0_arg"), DT_INT32, 0);
-  auto arg1 = ops::_Arg(scope.WithOpName("b_0_arg"), DT_FLOAT, 1);
-  auto arg2 = ops::_Arg(scope.WithOpName("c_0_arg"), DT_INT32, 2);
-  auto arg3 = ops::_Arg(scope.WithOpName("d_0_arg"), DT_FLOAT, 3);
-
-  auto arg4 = ops::_Arg(scope.WithOpName("u_0_arg"), DT_RESOURCE, 4);
-  auto arg5 = ops::_Arg(scope.WithOpName("v_0_arg"), DT_RESOURCE, 5);
-  auto arg6 = ops::_Arg(scope.WithOpName("w_0_arg"), DT_RESOURCE, 6);
-
-  auto add_attrs = [](Node* node) {
-    node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
-  };
-
-  auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
-
-  auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
-  add_attrs(read_u.node());
-  auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
-  add_attrs(read_v.node());
-  auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), arg6, DT_FLOAT);
-  add_attrs(read_w.node());
-
-  auto e = ops::Add(scope.WithOpName("E"), arg0, arg2);
-  add_attrs(e.node());
-  auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
-  add_attrs(f.node());
-  auto g = ops::Add(scope.WithOpName("G"), f, arg3);
-  add_attrs(g.node());
-
-  auto out0 = ops::_Retval(scope.WithOpName("b_identity_0_retval_RetVal"),
-                           b_identity, 0);
-  auto out1 = ops::_Retval(scope.WithOpName("e_0_retval_RetVal"), e, 1);
-  auto out2 = ops::_Retval(scope.WithOpName("g_0_retval_RetVal"), g, 2);
-  auto out3 =
-      ops::_Retval(scope.WithOpName("readu_0_retval_RetVal"), read_u, 3);
-
-  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  TF_CHECK_OK(scope.ToGraph(graph.get()));
-  return graph;
-}
-
-TEST(EncapsulateXlaComputations, DeterministicEncapsulate) {
-  // Test that control edge insertion order doesn't affect the cache key
-  // (cluster name) generated by TPU encapsulate pass.
-  auto get_serialized_graph = [](bool control_input_reversed,
-                                 bool operand_reversed) -> string {
-    FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
-    std::unique_ptr<Graph> graph(new Graph(&flib_def));
-    {
-      Scope scope = Scope::NewRootScope().ExitOnError();
-      auto a0 = ops::Placeholder(scope.WithOpName("A0"), DT_INT32);
-      auto a1 = ops::Placeholder(scope.WithOpName("A1"), DT_INT32);
-
-      ops::Add e = operand_reversed ? ops::Add(scope.WithOpName("E"), a0, a1)
-                                    : ops::Add(scope.WithOpName("E"), a1, a0);
-
-      auto add_attrs = [](Node* node) {
-        node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr,
-                      "launch0");
-      };
-      add_attrs(e.node());
-
-      TF_CHECK_OK(scope.ToGraph(graph.get()));
-      auto get_node_in_graph = [&graph](Node* node) {
-        return graph->FindNodeId(node->id());
-      };
-      // Insert control edge in different order. The order should not affect
-      // the encapsulated or serialized graph.
-      if (!control_input_reversed) {
-        graph->AddControlEdge(get_node_in_graph(a0.node()),
-                              get_node_in_graph(e.node()), true);
-        graph->AddControlEdge(get_node_in_graph(a1.node()),
-                              get_node_in_graph(e.node()), true);
-      } else {
-        graph->AddControlEdge(get_node_in_graph(a1.node()),
-                              get_node_in_graph(e.node()), true);
-        graph->AddControlEdge(get_node_in_graph(a0.node()),
-                              get_node_in_graph(e.node()), true);
-      }
-    }
-    TF_CHECK_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
-    GraphDef gdef;
-    graph->ToGraphDef(&gdef);
-    // Before serialization, sort control inputs first to remove
-    // nondeterminism.
-    SortControlInputs(&gdef);
-    string serialized;
-    SerializeToStringDeterministic(gdef, &serialized);
-    return serialized;
-  };
-
-  // Changing the order of control input shouldn't affect the graph generated.
-  EXPECT_EQ(get_serialized_graph(/*control_input_reversed=*/true,
-                                 /*operand_reversed=*/false),
-            get_serialized_graph(/*control_input_reversed=*/false,
-                                 /*operand_reversed=*/false));
-
-  // Changing the order of data input should affect the graph generated.
-  EXPECT_NE(get_serialized_graph(/*control_input_reversed=*/false,
-                                 /*operand_reversed=*/true),
-            get_serialized_graph(/*control_input_reversed=*/false,
-                                 /*operand_reversed=*/false));
-}
-
-TEST(EncapsulateXlaComputations, Encapsulate) {
-  FunctionLibraryDefinition flib_def(OpRegistry::Global(), {});
-  std::unique_ptr<Graph> graph(new Graph(&flib_def));
-  {
-    Scope scope = Scope::NewRootScope().ExitOnError();
-    auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
-    auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
-    auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
-    auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
-    auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
-    auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
-    auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
-    auto add_attrs = [](Node* node) {
-      node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
-    };
-
-    auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
-    add_attrs(b_identity.node());
-
-    auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), u, DT_FLOAT);
-    add_attrs(read_u.node());
-    auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), v, DT_FLOAT);
-    add_attrs(read_v.node());
-    auto read_w = ops::ReadVariableOp(scope.WithOpName("ReadW"), w, DT_FLOAT);
-    add_attrs(read_w.node());
-
-    auto e = ops::Add(scope.WithOpName("E"), a, c);
-    add_attrs(e.node());
-    auto f = ops::Add(scope.WithOpName("F"), read_v, read_w);
-    add_attrs(f.node());
-    auto g = ops::Add(scope.WithOpName("G"), f, d);
-    add_attrs(g.node());
-
-    auto out0 = ops::XlaClusterOutput(scope.WithOpName("Out0"), b_identity);
-    auto out1 = ops::XlaClusterOutput(scope.WithOpName("Out1"), e);
-    auto out2 = ops::XlaClusterOutput(scope.WithOpName("Out2"), g);
-    auto out3 = ops::XlaClusterOutput(scope.WithOpName("Out3"), read_u);
-
-    auto consumer0_a = ops::Identity(scope.WithOpName("consumer0_a"), out0);
-    auto consumer0_b = ops::Identity(scope.WithOpName("consumer0_b"), out0);
-    auto consumer0_c = ops::Identity(scope.WithOpName("consumer0_c"), out0);
-    auto consumer1 = ops::Identity(scope.WithOpName("consumer1"), out1);
-    auto consumer2 = ops::Identity(scope.WithOpName("consumer2"), out2);
-    auto consumer3 = ops::Identity(scope.WithOpName("consumer3"), out3);
-    TF_ASSERT_OK(scope.ToGraph(graph.get()));
-  }
-
-  std::unique_ptr<Graph> graph_copy(new Graph(&flib_def));
-  CopyGraph(*graph, graph_copy.get());
-
-  TF_ASSERT_OK(EncapsulateXlaComputationsPass::Encapsulate(&graph, &flib_def));
-
-  std::unordered_map<string, Node*> index = BuildNodeIndex(*graph);
-  string function = index.at("launch0")->type_string();
-
-  // Tests the outer graph is as expected.
-  {
-    std::unique_ptr<Graph> outer = MakeOuterGraph(flib_def, function);
-    GraphDef expected_def;
-    outer->ToGraphDef(&expected_def);
-
-    GraphDef actual_def;
-    graph->ToGraphDef(&actual_def);
-    TF_EXPECT_GRAPH_EQ_INTERNAL(expected_def, actual_def);
-  }
-
-  // Tests the encapsulated body graph is as expected.
-  {
-    std::unique_ptr<Graph> body = MakeBodyGraph();
-    GraphDef expected_body_def;
-    body->ToGraphDef(&expected_body_def);
-
-    InstantiationResultForTest result;
-    TF_EXPECT_OK(InstantiateFunctionForTest(function, flib_def, &result));
-
-    EXPECT_EQ((DataTypeVector{DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT,
-                              DT_RESOURCE, DT_RESOURCE, DT_RESOURCE}),
-              result.arg_types);
-    EXPECT_EQ((DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}),
-              result.ret_types);
-    TF_EXPECT_GRAPH_EQ(expected_body_def, result.gdef);
-  }
-
-  // Encapsulates the same computation again, verifies we reuse the same
-  // function. Encapsulation should be deterministic to avoid recompilation.
-  TF_ASSERT_OK(
-      EncapsulateXlaComputationsPass::Encapsulate(&graph_copy, &flib_def));
-  std::unordered_map<string, Node*> index_copy = BuildNodeIndex(*graph_copy);
-  string function_copy = index_copy.at("launch0")->type_string();
-  EXPECT_EQ(function, function_copy);
-}
-
-TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
-  std::unique_ptr<Graph> body_graph = MakeBodyGraph();
-  FunctionDefLibrary flib;
-  TF_ASSERT_OK(GraphToFunctionDef(*body_graph, "launch0", flib.add_function()));
-
-  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
-
-  std::unique_ptr<Graph> graph = MakeOuterGraph(flib_def, "launch0");
-  TF_ASSERT_OK(EncapsulateXlaComputationsPass::BuildXlaLaunchOps(graph.get()));
-
-  Scope scope = Scope::DisabledShapeInferenceScope().ExitOnError();
-  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));
-
-  auto a = ops::Placeholder(scope.WithOpName("A"), DT_INT32);
-  auto b = ops::Placeholder(scope.WithOpName("B"), DT_FLOAT);
-  auto c = ops::Placeholder(scope.WithOpName("C"), DT_INT32);
-  auto d = ops::Placeholder(scope.WithOpName("D"), DT_FLOAT);
-  auto u = ops::Placeholder(scope.WithOpName("U"), DT_RESOURCE);
-  auto v = ops::Placeholder(scope.WithOpName("V"), DT_RESOURCE);
-  auto w = ops::Placeholder(scope.WithOpName("W"), DT_RESOURCE);
-
-  NameAttrList function;
-  function.set_name("launch0");
-  auto launch = ops::XlaLaunch(
-      scope.WithOpName("launch0"), std::initializer_list<Input>{},
-      std::initializer_list<Input>{a, b, c, d},
-      std::initializer_list<Input>{u, v, w},
-      DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
-
-  auto consumer0_a =
-      ops::Identity(scope.WithOpName("consumer0_a"), launch.results[0]);
-  auto consumer0_b =
-      ops::Identity(scope.WithOpName("consumer0_b"), launch.results[0]);
-  auto consumer0_c =
-      ops::Identity(scope.WithOpName("consumer0_c"), launch.results[0]);
-  auto consumer1 =
-      ops::Identity(scope.WithOpName("consumer1"), launch.results[1]);
-  auto consumer2 =
-      ops::Identity(scope.WithOpName("consumer2"), launch.results[2]);
-  auto consumer3 =
-      ops::Identity(scope.WithOpName("consumer3"), launch.results[3]);
-
-  GraphDef expected_def;
-  TF_ASSERT_OK(scope.ToGraphDef(&expected_def));
-
-  GraphDef actual_def;
-  graph->ToGraphDef(&actual_def);
-  TF_EXPECT_GRAPH_EQ(expected_def, actual_def);
-}
-
-}  // namespace tensorflow
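The DeterministicEncapsulate test above pins down exactly the property the removed SortControlInputs provides: once control inputs are sorted, graphs that differ only in control-edge insertion order serialize identically, while a real data change still alters the fingerprint. A compressed illustration of the same property follows, using a plain string join as the "serialization" and std::hash in place of TensorFlow's Fingerprint64; both are simplifications (and std::hash, unlike Fingerprint64, is not stable across platforms).

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Sorts control inputs, joins everything into one string, and hashes it.
std::size_t FingerprintSketch(std::vector<std::string> inputs) {
  std::stable_sort(inputs.begin(), inputs.end(),
                   [](const std::string& a, const std::string& b) {
                     bool ac = !a.empty() && a[0] == '^';
                     bool bc = !b.empty() && b[0] == '^';
                     return (!ac && bc) || (ac && bc && a < b);
                   });
  std::string serialized;
  for (const std::string& s : inputs) serialized += s + ";";
  return std::hash<std::string>{}(serialized);
}

int main() {
  // Same node, control edges added in opposite orders.
  std::size_t f1 = FingerprintSketch({"x:0", "^A0", "^A1"});
  std::size_t f2 = FingerprintSketch({"x:0", "^A1", "^A0"});
  assert(f1 == f2);  // Control-input order no longer affects the fingerprint.

  // Swapping a *data* operand is a real graph change and should differ.
  std::size_t f3 = FingerprintSketch({"y:0", "^A0", "^A1"});
  assert(f1 != f3);
}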
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { -// EncapsulateXlaComputationsPass rewrites computations generated by the -// xla.compile() Python code into XlaLaunch nodes. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26, - EncapsulateXlaComputationsPass); - -// The following POST_REWRITE passes support auto-clustering to enable XLA. REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 1a29c3caab..f2473d98ff 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -13,14 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { -using shape_inference::InferenceContext; - REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") @@ -36,19 +32,4 @@ REGISTER_OP("XlaLaunch") .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); -REGISTER_OP("XlaClusterOutput") - .Input("input: T") - // Note: when replication is supported, this op will have N outputs. - .Output("outputs: T") - .Attr("T: type") - .SetShapeFn([](InferenceContext* c) { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(0)); - } - return Status::OK(); - }) - .Doc( - "Operator that connects the output of an XLA computation to other " - "consumer graph nodes."); - } // namespace tensorflow |