aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc')
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc54
1 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e786d41887..fdd71c6a58 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -60,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr =
namespace {
-bool AreAllParentsConst(const Node& n,
- const gtl::FlatSet<const Node*>& runtime_const_nodes) {
- if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
+bool AreAllParentsGuaranteedConst(
+ const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
return true;
@@ -93,7 +93,8 @@ void MarkGuaranteedConstants(
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
- if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
+ if (AreAllParentsGuaranteedConst(*n,
+ guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
@@ -137,7 +138,7 @@ class Encapsulator {
// Find subgraphs marked with 'group_attribute', and build a new
// subgraph, one for each value of 'group_attribute'.
- Status SplitIntoSubgraphs();
+ Status SplitIntoSubgraphs(FunctionLibraryDefinition* library);
// Build a FunctionDef for each subgraph, and add it 'library'. The values of
// the 'group_attribute' annotations become the function names.
@@ -1136,7 +1137,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
host_compute->AddAttr("shape_inference_graph", inference_graph_name);
host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
- TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator.
+ if (library->Find(inference_graph_name) == nullptr) {
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ }
}
return Status::OK();
}
@@ -1474,7 +1478,7 @@ Status Encapsulator::CopySubgraphEdges(
return Status::OK();
}
-Status Encapsulator::SplitIntoSubgraphs() {
+Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) {
Status s;
// Map from input graph nodes to subgraph nodes.
@@ -1509,6 +1513,15 @@ Status Encapsulator::SplitIntoSubgraphs() {
TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy));
}
+ if (VLOG_IS_ON(1)) {
+ // Dump subgraphs.
+ for (auto& entry : subgraphs_) {
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first),
+ *entry.second.GetGraph(), library);
+ }
+ }
+
return s;
}
@@ -1932,6 +1945,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
// continue.
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
+ VLOG(2) << "Node " << src_node->name()
+ << " has known shape: " << proto.DebugString();
if (dummy_node_images.find(src_node) == dummy_node_images.end()) {
dummy_node_images[src_node] =
AddDummyShapedNode(src_node, src_port, control_flow_info,
@@ -1949,6 +1964,8 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
if (VLOG_IS_ON(2)) {
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
+ VLOG(2) << "Node " << src_node->name()
+ << " has unknown shape: " << proto.DebugString();
}
stack.push_back({src_node, false});
}
@@ -2191,6 +2208,23 @@ Status Encapsulator::FindClusterDependencies() {
}
}
}
+ if (VLOG_IS_ON(2)) {
+ // Print debug information.
+ VLOG(2) << "node_ancestors_map:";
+ for (const auto& node_iter : node_ancestors_map) {
+ VLOG(2) << "\t" << node_iter.first->name() << ": subgraph = '"
+ << node_iter.second.subgraph
+ << "', outside_compilation_cluster = '"
+ << node_iter.second.outside_compilation_cluster
+ << "', ancestor_clusters: "
+ << (node_iter.second.ancestor_clusters.empty() ? "(empty)" : "");
+ for (const auto& cluster_iter : node_iter.second.ancestor_clusters) {
+ VLOG(2) << "\t\tsubgraph = '" << cluster_iter.subgraph
+ << "', outside_compilation_cluster = '"
+ << cluster_iter.outside_compilation_cluster << "'";
+ }
+ }
+ }
return Status::OK();
}
@@ -2398,7 +2432,7 @@ Status EncapsulateSubgraphsInFunctions(
std::move(outside_compilation_attribute),
&graph_in);
TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies());
- TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
+ TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs(library));
TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(
rewrite_subgraph_fn, reuse_existing_functions, library));
@@ -2447,7 +2481,7 @@ Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
+ dump_graph::DumpGraphToFile("encapsulate_subgraphs_before", **options.graph,
options.flib_def);
}
@@ -2530,7 +2564,7 @@ Status EncapsulateSubgraphsPass::Run(
"EncapsulateSubgraphsPass failed");
if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
+ dump_graph::DumpGraphToFile("encapsulate_subgraphs_after", *graph_out,
options.flib_def);
}