diff options
author | Peter Hawkins <phawkins@google.com> | 2017-04-28 12:01:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-28 13:24:14 -0700 |
commit | 30fdc1a3b36cfd4c8859be9d4f4ef1d951b59c21 (patch) | |
tree | 7d78949b7d8a6e761a38095a2309c9e7bf83f266 /tensorflow | |
parent | 5b85e8831367d4d02e9bb81d1d7475629b6255a7 (diff) |
[TF:XLA] Improve constant folding. Supply a null partition_device so the constant folder will constant-fold DT_INT32 operators.
TensorFlow core: Add support for constant folding of nodes with control dependencies.
Mark candidate sampling ops as stateful; the constant folder was now constant-folding them when it should not.
Change: 154572578
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 7 | ||||
-rw-r--r-- | tensorflow/core/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/constant_folding.cc | 140 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/constant_folding_test.cc | 43 | ||||
-rw-r--r-- | tensorflow/core/ops/candidate_sampling_ops.cc | 6 |
5 files changed, 133 insertions, 64 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 33b4a43aa1..155fc58577 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -115,11 +115,12 @@ Status XlaCompiler::CompileFunction( } // Optimize the graph before running the compiler. - // TODO(pbar): The constant folder currently does not simplify int32 - // operations for devices other than CPU. OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); - OptimizeGraph(flr, &graph); + optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &graph); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile( diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 13f666f394..79fd7ec01e 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -2224,6 +2224,7 @@ tf_cc_test( "//tensorflow/core/kernels:bcast_ops", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:concat_op", + "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:immutable_constant_op", "//tensorflow/core/kernels:matmul_op", diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 5b604189e1..8fa61d098e 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/graph/subgraph.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" @@ -44,6 +45,9 @@ namespace { bool IsConstantFoldable(const Node* n, const std::function<bool(const Node*)>& consider) { + if (n->IsConstant()) { + return true; + } if (n->op_def().is_stateful()) { return false; } @@ -78,47 +82,60 @@ bool IsConstantFoldable(const Node* n, return true; } -// Returns the constant foldable nodes in `nodes_result` in data flow order. -void FindConstantFoldableNodes(const Graph* graph, - const FunctionLibraryDefinition* flib_def, - ConstantFoldingOptions opts, - std::vector<Node*>* nodes_result) { - std::set<const Node*> node_set; - std::vector<Node*>& nodes = *nodes_result; +// Returns the constant foldable nodes in `nodes` in topological order. +// Populates `constant_control_deps` with the non-constant control depedencies +// of each constant node. +void FindConstantFoldableNodes( + const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes, + std::unordered_map<const Node*, gtl::FlatSet<Node*>>* + constant_control_deps) { bool internal_node_inserted = false; // Walk the nodes in data flow order - ReverseDFS(*graph, nullptr, [&nodes, &node_set, &internal_node_inserted, opts, - flib_def](Node* n) { - if (n->IsConstant()) { - // Constants with no control inputs (except from _SOURCE node) - // are definitely constant foldable. - if (n->in_edges().size() == 0 || - (n->in_edges().size() == 1 && - (*n->in_edges().begin())->src()->IsSource())) { - node_set.insert(n); - nodes.push_back(n); - } - } else if (IsConstantFoldable(n, opts.consider)) { - // Check whether the set of this node's in_nodes is completely - // included in the set of constant foldable nodes. If true, - // then this node is also constant foldable. 
- bool all_parents_constant = true; - for (const Node* parent : n->in_nodes()) { - if (node_set.count(parent) == 0 && !parent->IsSource()) { - all_parents_constant = false; - break; + ReverseDFS( + *graph, nullptr, + [nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) { + if (IsConstantFoldable(n, opts.consider)) { + // A node is constant provided all of its non-control + // incoming Tensors come from constant nodes. + // + // We allow control dependencies from non-constant nodes to constant + // nodes, but to preserve the graph structure we must transfer the + // control dependency onto any constant replacement. + bool all_parents_constant = true; + for (const Edge* in : n->in_edges()) { + // Allows non-constant -> constant control edges. + if (!in->IsControlEdge() && + constant_control_deps->count(in->src()) == 0) { + all_parents_constant = false; + break; + } + } + if (all_parents_constant) { + gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n]; + for (const Edge* e : n->in_edges()) { + if (constant_control_deps->count(e->src()) == 0) { + if (!e->src()->IsSource()) { + control_deps.insert(e->src()); + } + } else { + // If the parent is constant, add all of its transitive control + // deps. + const gtl::FlatSet<Node*>& parent_deps = + (*constant_control_deps)[e->src()]; + control_deps.insert(parent_deps.begin(), parent_deps.end()); + } + } + nodes->push_back(n); + if (!n->IsConstant()) { + internal_node_inserted = true; + } + } } - } - if (all_parents_constant) { - node_set.insert(n); - nodes.push_back(n); - internal_node_inserted = true; - } - } - }); + }); // If we have inserted just leaf level nodes, then there is nothing to fold. 
if (!internal_node_inserted) { - nodes.clear(); + nodes->clear(); + constant_control_deps->clear(); } } @@ -134,23 +151,21 @@ Graph* GetConstantGraph(const Graph* orig_graph, std::map<NodeAndOutput, Node*>* tensors_to_fetch) { Graph* constant_graph = new Graph(orig_graph->op_registry()); std::unordered_map<Node*, Node*> node_map; - std::set<Node*> already_added; - already_added.insert(constant_graph->source_node()); - already_added.insert(constant_graph->sink_node()); node_map[orig_graph->source_node()] = constant_graph->source_node(); node_map[orig_graph->sink_node()] = constant_graph->sink_node(); for (Node* n : nodes) { Node* added = constant_graph->CopyNode(n); node_map[n] = added; - already_added.insert(added); for (const Edge* in_edge : n->in_edges()) { - Node* in = in_edge->src(); - CHECK_GT(node_map.count(in), size_t{0}) << n->DebugString() << " <-" - << in->DebugString(); - CHECK_GT(already_added.count(node_map[in]), size_t{0}) - << in->DebugString(); - constant_graph->AddEdge(node_map[in], in_edge->src_output(), added, - in_edge->dst_input()); + // Don't copy control edges to the constant graph. + if (!in_edge->IsControlEdge()) { + Node* in = in_edge->src(); + auto it = node_map.find(in); + CHECK(it != node_map.end()) + << n->DebugString() << " <-" << in->DebugString(); + constant_graph->AddEdge(it->second, in_edge->src_output(), added, + in_edge->dst_input()); + } } } @@ -176,8 +191,11 @@ int64 UniqueConstantId() { // the value supplied in 'constant'. 'partition_device', if non-null // is the device where the graph executes. Returns true if the // replacement was successful, false otherwise. +// 'control_deps' is the set of nodes that should be control predecessors of the +// new constant node. 
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, - NodeAndOutput tensor, const Tensor& constant) { + NodeAndOutput tensor, const Tensor& constant, + const gtl::FlatSet<Node*>& control_deps) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY @@ -241,8 +259,8 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, return false; } - VLOG(1) << "Replacing " << tensor.first->DebugString() - << " :: " << tensor.second << " with a constant"; + VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second + << " with a constant"; if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; @@ -251,7 +269,13 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device, graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); } - graph->AddEdge(graph->source_node(), -1, constant_node, -1); + if (control_deps.empty()) { + graph->AddControlEdge(graph->source_node(), constant_node); + } else { + for (Node* node : control_deps) { + graph->AddControlEdge(node, constant_node); + } + } if (partition_device) { constant_node->set_assigned_device_name(partition_device->name()); } @@ -265,13 +289,10 @@ Status ConstantFold(const ConstantFoldingOptions& opts, Device* partition_device, Graph* graph, bool* was_mutated) { DumpGraph("Before", graph); - const FunctionLibraryDefinition* flib_def = nullptr; - if (function_library) { - flib_def = function_library->GetFunctionLibraryDefinition(); - } - std::vector<Node*> constant_foldable_nodes; - FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes); + std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps; + FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes, + &constant_control_deps); if (constant_foldable_nodes.empty()) { VLOG(1) << "No constant foldable nodes 
found"; *was_mutated = false; @@ -324,8 +345,11 @@ Status ConstantFold(const ConstantFoldingOptions& opts, // original graph with those constants. int32 num_nodes_replaced = 0; for (size_t c = 0; c < outputs.size(); ++c) { + const gtl::FlatSet<Node*>& control_deps = + constant_control_deps[tensors_to_replace[c].first]; if (ReplaceTensorWithConstant(graph, partition_device, - tensors_to_replace[c], outputs[c])) { + tensors_to_replace[c], outputs[c], + control_deps)) { ++num_nodes_replaced; } } diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index a4612aba72..e45490ee06 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -230,14 +230,14 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) { Node* b1_ident = index.at("b1_ident"); // 0th output of b should have been folded. - EXPECT_EQ(1, b0->num_inputs()); + ASSERT_EQ(1, b0->num_inputs()); ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2}); // 1st output of b should still be b1_ident. However, b1_ident's input must // have been replaced with a constant. 
- EXPECT_EQ(1, b1->num_inputs()); + ASSERT_EQ(1, b1->num_inputs()); EXPECT_EQ(*(b1->in_nodes().begin()), b1_ident); - EXPECT_EQ(1, b1_ident->num_inputs()); + ASSERT_EQ(1, b1_ident->num_inputs()); ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0}); } @@ -325,6 +325,43 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { EXPECT_FALSE(was_mutated); } +TEST_F(ConstantFoldingTest, ControlDependencies) { + Graph g(OpRegistry::Global()); + { + Scope s = Scope::NewRootScope(); + auto c0 = ops::Const<int>(s, 1); + auto recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1", "sender", + 0, "receiver"); + auto c1 = ops::Const<int>(s.WithControlDependencies(recv1), 2); + auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender", + 0, "receiver"); + auto c2 = ops::Const<int>(s.WithControlDependencies(recv2), 3); + auto add = ops::Add(s.WithControlDependencies(c2), c0, c1); + auto send = + ops::_Send(s.WithOpName("send"), add, "send", "sender", 0, "receiver"); + TF_ASSERT_OK(s.ToGraph(&g)); + } + bool was_mutated; + TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(), + nullptr, &g, &was_mutated)); + EXPECT_TRUE(was_mutated); + + std::unordered_map<string, Node*> index = NodeNameIndex(g); + Node* recv1 = index.at("recv1"); + Node* recv2 = index.at("recv2"); + Node* send = index.at("send"); + + ASSERT_EQ(1, send->num_inputs()); + Node* p = *(send->in_nodes().begin()); + ExpectNodeEqual<int>(p, {3}, {}); + + ASSERT_EQ(2, p->in_edges().size()); + for (const Edge* e : p->in_edges()) { + EXPECT_TRUE(e->IsControlEdge()); + EXPECT_TRUE(e->src() == recv1 || e->src() == recv2) << e->src()->name(); + } +} + namespace { const char kTestMemRegionName[] = "test://test"; diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc index 037c393574..945e0b068a 100644 --- a/tensorflow/core/ops/candidate_sampling_ops.cc +++ b/tensorflow/core/ops/candidate_sampling_ops.cc @@ -55,6 +55,7 @@ 
REGISTER_OP("UniformCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a uniform distribution. @@ -103,6 +104,7 @@ REGISTER_OP("LogUniformCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a log-uniform distribution. @@ -152,6 +154,7 @@ REGISTER_OP("LearnedUnigramCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a learned unigram distribution. @@ -200,6 +203,7 @@ REGISTER_OP("ThreadUnsafeUnigramCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a learned unigram distribution. @@ -254,6 +258,7 @@ REGISTER_OP("FixedUnigramCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a learned unigram distribution. @@ -329,6 +334,7 @@ REGISTER_OP("AllCandidateSampler") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(CandidateSamplerShapeFn) + .SetIsStateful() .Doc(R"doc( Generates labels for candidate sampling with a learned unigram distribution. |