author    Peter Hawkins <phawkins@google.com> 2017-04-28 12:01:55 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-04-28 13:24:14 -0700
commit 30fdc1a3b36cfd4c8859be9d4f4ef1d951b59c21 (patch)
tree   7d78949b7d8a6e761a38095a2309c9e7bf83f266 /tensorflow
parent 5b85e8831367d4d02e9bb81d1d7475629b6255a7 (diff)
[TF:XLA] Improve constant folding. Supply a null partition_device so the constant folder will constant-fold DT_INT32 operators.
TensorFlow core: Add support for constant folding of nodes with control dependencies. Mark candidate sampling ops as stateful; the improved constant folder would otherwise fold them where it should not. Change: 154572578
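In graph terms, the rule this change implements is: a node is foldable when all of its non-control inputs come from foldable nodes, and control dependencies from non-constant nodes (including any inherited transitively through folded parents) must be re-attached to the replacement constant. Below is a minimal sketch of that bookkeeping in plain C++; the Node struct and node names are illustrative stand-ins, not the TensorFlow API:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for tensorflow::Node; not the real API.
struct Node {
  std::string name;
  bool is_constant = false;    // a Const op (always foldable after this change)
  std::vector<Node*> data_in;  // non-control (tensor) inputs
  std::vector<Node*> ctrl_in;  // control inputs (assumed non-constant here;
                               // the real code also handles constant sources)
};

int main() {
  // Loosely mirrors the new unit test: recv1 --ctrl--> c1; c0 and c1 feed add.
  Node recv1{"recv1"};                      // stateful producer, not foldable
  Node c0{"c0", true, {}, {}};
  Node c1{"c1", true, {}, {&recv1}};
  Node add{"add", false, {&c0, &c1}, {}};

  // Per-node accumulated control deps, as in `constant_control_deps`.
  std::unordered_map<Node*, std::unordered_set<Node*>> deps;

  // Visit in topological order (hand-ordered here for brevity).
  for (Node* n : {&c0, &c1, &add}) {
    bool parents_foldable = true;
    for (Node* p : n->data_in) parents_foldable &= deps.count(p) > 0;
    if (!n->is_constant && !parents_foldable) continue;  // not foldable
    auto& d = deps[n];
    d.insert(n->ctrl_in.begin(), n->ctrl_in.end());  // keep own control deps
    for (Node* p : n->data_in)                       // inherit parents' deps
      d.insert(deps[p].begin(), deps[p].end());
  }

  // The constant that replaces `add` must keep a control edge from recv1.
  for (Node* d : deps[&add]) std::cout << d->name << "\n";  // prints: recv1
}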
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc | 7
-rw-r--r-- tensorflow/core/BUILD | 1
-rw-r--r-- tensorflow/core/common_runtime/constant_folding.cc | 140
-rw-r--r-- tensorflow/core/common_runtime/constant_folding_test.cc | 43
-rw-r--r-- tensorflow/core/ops/candidate_sampling_ops.cc | 6
5 files changed, 133 insertions, 64 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 33b4a43aa1..155fc58577 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -115,11 +115,12 @@ Status XlaCompiler::CompileFunction(
}
// Optimize the graph before running the compiler.
- // TODO(pbar): The constant folder currently does not simplify int32
- // operations for devices other than CPU.
OptimizerOptions opts;
+ opts.set_do_common_subexpression_elimination(true);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
GraphOptimizer optimizer(opts);
- OptimizeGraph(flr, &graph);
+ optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &graph);
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile(
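The tf2xla change above replaces the OptimizeGraph() convenience wrapper with an explicitly configured GraphOptimizer so that a null device can be passed through to the constant folder. A sketch of the same invocation in isolation (assuming, as in CompileFunction, a FunctionLibraryRuntime* flr and a std::unique_ptr<Graph> graph in scope):

#include "tensorflow/core/common_runtime/graph_optimizer.h"

// Enable the standard passes explicitly rather than relying on defaults.
OptimizerOptions opts;
opts.set_do_common_subexpression_elimination(true);
opts.set_do_function_inlining(true);
opts.set_do_constant_folding(true);

GraphOptimizer optimizer(opts);
// Passing a null device relaxes the constant folder's placement checks, so
// DT_INT32 operators are folded even though the target is an XLA device.
optimizer.Optimize(flr, flr->env(), /*device=*/nullptr, &graph);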
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 13f666f394..79fd7ec01e 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2224,6 +2224,7 @@ tf_cc_test(
"//tensorflow/core/kernels:bcast_ops",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:concat_op",
+ "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:immutable_constant_op",
"//tensorflow/core/kernels:matmul_op",
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 5b604189e1..8fa61d098e 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/session_options.h"
@@ -44,6 +45,9 @@ namespace {
bool IsConstantFoldable(const Node* n,
const std::function<bool(const Node*)>& consider) {
+ if (n->IsConstant()) {
+ return true;
+ }
if (n->op_def().is_stateful()) {
return false;
}
@@ -78,47 +82,60 @@ bool IsConstantFoldable(const Node* n,
return true;
}
-// Returns the constant foldable nodes in `nodes_result` in data flow order.
-void FindConstantFoldableNodes(const Graph* graph,
- const FunctionLibraryDefinition* flib_def,
- ConstantFoldingOptions opts,
- std::vector<Node*>* nodes_result) {
- std::set<const Node*> node_set;
- std::vector<Node*>& nodes = *nodes_result;
+// Returns the constant foldable nodes in `nodes` in topological order.
+// Populates `constant_control_deps` with the non-constant control dependencies
+// of each constant node.
+void FindConstantFoldableNodes(
+ const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes,
+ std::unordered_map<const Node*, gtl::FlatSet<Node*>>*
+ constant_control_deps) {
bool internal_node_inserted = false;
// Walk the nodes in data flow order
- ReverseDFS(*graph, nullptr, [&nodes, &node_set, &internal_node_inserted, opts,
- flib_def](Node* n) {
- if (n->IsConstant()) {
- // Constants with no control inputs (except from _SOURCE node)
- // are definitely constant foldable.
- if (n->in_edges().size() == 0 ||
- (n->in_edges().size() == 1 &&
- (*n->in_edges().begin())->src()->IsSource())) {
- node_set.insert(n);
- nodes.push_back(n);
- }
- } else if (IsConstantFoldable(n, opts.consider)) {
- // Check whether the set of this node's in_nodes is completely
- // included in the set of constant foldable nodes. If true,
- // then this node is also constant foldable.
- bool all_parents_constant = true;
- for (const Node* parent : n->in_nodes()) {
- if (node_set.count(parent) == 0 && !parent->IsSource()) {
- all_parents_constant = false;
- break;
+ ReverseDFS(
+ *graph, nullptr,
+ [nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) {
+ if (IsConstantFoldable(n, opts.consider)) {
+ // A node is constant provided all of its non-control
+ // incoming Tensors come from constant nodes.
+ //
+ // We allow control dependencies from non-constant nodes to constant
+ // nodes, but to preserve the graph structure we must transfer the
+ // control dependency onto any constant replacement.
+ bool all_parents_constant = true;
+ for (const Edge* in : n->in_edges()) {
+ // Allows non-constant -> constant control edges.
+ if (!in->IsControlEdge() &&
+ constant_control_deps->count(in->src()) == 0) {
+ all_parents_constant = false;
+ break;
+ }
+ }
+ if (all_parents_constant) {
+ gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
+ for (const Edge* e : n->in_edges()) {
+ if (constant_control_deps->count(e->src()) == 0) {
+ if (!e->src()->IsSource()) {
+ control_deps.insert(e->src());
+ }
+ } else {
+ // If the parent is constant, add all of its transitive control
+ // deps.
+ const gtl::FlatSet<Node*>& parent_deps =
+ (*constant_control_deps)[e->src()];
+ control_deps.insert(parent_deps.begin(), parent_deps.end());
+ }
+ }
+ nodes->push_back(n);
+ if (!n->IsConstant()) {
+ internal_node_inserted = true;
+ }
+ }
}
- }
- if (all_parents_constant) {
- node_set.insert(n);
- nodes.push_back(n);
- internal_node_inserted = true;
- }
- }
- });
+ });
// If we have inserted just leaf level nodes, then there is nothing to fold.
if (!internal_node_inserted) {
- nodes.clear();
+ nodes->clear();
+ constant_control_deps->clear();
}
}
@@ -134,23 +151,21 @@ Graph* GetConstantGraph(const Graph* orig_graph,
std::map<NodeAndOutput, Node*>* tensors_to_fetch) {
Graph* constant_graph = new Graph(orig_graph->op_registry());
std::unordered_map<Node*, Node*> node_map;
- std::set<Node*> already_added;
- already_added.insert(constant_graph->source_node());
- already_added.insert(constant_graph->sink_node());
node_map[orig_graph->source_node()] = constant_graph->source_node();
node_map[orig_graph->sink_node()] = constant_graph->sink_node();
for (Node* n : nodes) {
Node* added = constant_graph->CopyNode(n);
node_map[n] = added;
- already_added.insert(added);
for (const Edge* in_edge : n->in_edges()) {
- Node* in = in_edge->src();
- CHECK_GT(node_map.count(in), size_t{0}) << n->DebugString() << " <-"
- << in->DebugString();
- CHECK_GT(already_added.count(node_map[in]), size_t{0})
- << in->DebugString();
- constant_graph->AddEdge(node_map[in], in_edge->src_output(), added,
- in_edge->dst_input());
+ // Don't copy control edges to the constant graph.
+ if (!in_edge->IsControlEdge()) {
+ Node* in = in_edge->src();
+ auto it = node_map.find(in);
+ CHECK(it != node_map.end())
+ << n->DebugString() << " <-" << in->DebugString();
+ constant_graph->AddEdge(it->second, in_edge->src_output(), added,
+ in_edge->dst_input());
+ }
}
}
@@ -176,8 +191,11 @@ int64 UniqueConstantId() {
the value supplied in 'constant'. 'partition_device', if non-null,
is the device where the graph executes. Returns true if the
// replacement was successful, false otherwise.
+// 'control_deps' is the set of nodes that should be control predecessors of the
+// new constant node.
bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
- NodeAndOutput tensor, const Tensor& constant) {
+ NodeAndOutput tensor, const Tensor& constant,
+ const gtl::FlatSet<Node*>& control_deps) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
// 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
@@ -241,8 +259,8 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
return false;
}
- VLOG(1) << "Replacing " << tensor.first->DebugString()
- << " :: " << tensor.second << " with a constant";
+ VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second
+ << " with a constant";
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
@@ -251,7 +269,13 @@ bool ReplaceTensorWithConstant(Graph* graph, Device* partition_device,
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
}
- graph->AddEdge(graph->source_node(), -1, constant_node, -1);
+ if (control_deps.empty()) {
+ graph->AddControlEdge(graph->source_node(), constant_node);
+ } else {
+ for (Node* node : control_deps) {
+ graph->AddControlEdge(node, constant_node);
+ }
+ }
if (partition_device) {
constant_node->set_assigned_device_name(partition_device->name());
}
@@ -265,13 +289,10 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
Device* partition_device, Graph* graph, bool* was_mutated) {
DumpGraph("Before", graph);
- const FunctionLibraryDefinition* flib_def = nullptr;
- if (function_library) {
- flib_def = function_library->GetFunctionLibraryDefinition();
- }
-
std::vector<Node*> constant_foldable_nodes;
- FindConstantFoldableNodes(graph, flib_def, opts, &constant_foldable_nodes);
+ std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
+ FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
+ &constant_control_deps);
if (constant_foldable_nodes.empty()) {
VLOG(1) << "No constant foldable nodes found";
*was_mutated = false;
@@ -324,8 +345,11 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
// original graph with those constants.
int32 num_nodes_replaced = 0;
for (size_t c = 0; c < outputs.size(); ++c) {
+ const gtl::FlatSet<Node*>& control_deps =
+ constant_control_deps[tensors_to_replace[c].first];
if (ReplaceTensorWithConstant(graph, partition_device,
- tensors_to_replace[c], outputs[c])) {
+ tensors_to_replace[c], outputs[c],
+ control_deps)) {
++num_nodes_replaced;
}
}
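Caller-side, the entry point is unchanged; a sketch of invoking it with the new behavior, inside a Status-returning function and with a populated Graph g (this mirrors the call in the new unit test below):

#include "tensorflow/core/common_runtime/constant_folding.h"

bool was_mutated = false;
// Null function_library and partition_device: the folder applies its
// permissive placement rules and consults only the graph itself.
TF_RETURN_IF_ERROR(ConstantFold(ConstantFoldingOptions{},
                                /*function_library=*/nullptr, Env::Default(),
                                /*partition_device=*/nullptr, &g,
                                &was_mutated));
// Each constant that replaced a folded subtree now carries control edges
// from the non-constant control predecessors of the nodes it replaced.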
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index a4612aba72..e45490ee06 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -230,14 +230,14 @@ TEST_F(ConstantFoldingTest, TwoOutputsFoldOneOutput) {
Node* b1_ident = index.at("b1_ident");
// 0th output of b should have been folded.
- EXPECT_EQ(1, b0->num_inputs());
+ ASSERT_EQ(1, b0->num_inputs());
ExpectNodeEqual<int>(*(b0->in_nodes().begin()), {0, 1}, {2});
// 1st output of b should still be b1_ident. However, b1_ident's input must
// have been replaced with a constant.
- EXPECT_EQ(1, b1->num_inputs());
+ ASSERT_EQ(1, b1->num_inputs());
EXPECT_EQ(*(b1->in_nodes().begin()), b1_ident);
- EXPECT_EQ(1, b1_ident->num_inputs());
+ ASSERT_EQ(1, b1_ident->num_inputs());
ExpectNodeEqual<int>(*(b1_ident->in_nodes().begin()), {}, {0});
}
@@ -325,6 +325,43 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) {
EXPECT_FALSE(was_mutated);
}
+TEST_F(ConstantFoldingTest, ControlDependencies) {
+ Graph g(OpRegistry::Global());
+ {
+ Scope s = Scope::NewRootScope();
+ auto c0 = ops::Const<int>(s, 1);
+ auto recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1", "sender",
+ 0, "receiver");
+ auto c1 = ops::Const<int>(s.WithControlDependencies(recv1), 2);
+ auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
+ 0, "receiver");
+ auto c2 = ops::Const<int>(s.WithControlDependencies(recv2), 3);
+ auto add = ops::Add(s.WithControlDependencies(c2), c0, c1);
+ auto send =
+ ops::_Send(s.WithOpName("send"), add, "send", "sender", 0, "receiver");
+ TF_ASSERT_OK(s.ToGraph(&g));
+ }
+ bool was_mutated;
+ TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, Env::Default(),
+ nullptr, &g, &was_mutated));
+ EXPECT_TRUE(was_mutated);
+
+ std::unordered_map<string, Node*> index = NodeNameIndex(g);
+ Node* recv1 = index.at("recv1");
+ Node* recv2 = index.at("recv2");
+ Node* send = index.at("send");
+
+ ASSERT_EQ(1, send->num_inputs());
+ Node* p = *(send->in_nodes().begin());
+ ExpectNodeEqual<int>(p, {3}, {});
+
+ ASSERT_EQ(2, p->in_edges().size());
+ for (const Edge* e : p->in_edges()) {
+ EXPECT_TRUE(e->IsControlEdge());
+ EXPECT_TRUE(e->src() == recv1 || e->src() == recv2) << e->src()->name();
+ }
+}
+
namespace {
const char kTestMemRegionName[] = "test://test";
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc
index 037c393574..945e0b068a 100644
--- a/tensorflow/core/ops/candidate_sampling_ops.cc
+++ b/tensorflow/core/ops/candidate_sampling_ops.cc
@@ -55,6 +55,7 @@ REGISTER_OP("UniformCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a uniform distribution.
@@ -103,6 +104,7 @@ REGISTER_OP("LogUniformCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a log-uniform distribution.
@@ -152,6 +154,7 @@ REGISTER_OP("LearnedUnigramCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a learned unigram distribution.
@@ -200,6 +203,7 @@ REGISTER_OP("ThreadUnsafeUnigramCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a learned unigram distribution.
@@ -254,6 +258,7 @@ REGISTER_OP("FixedUnigramCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a learned unigram distribution.
@@ -329,6 +334,7 @@ REGISTER_OP("AllCandidateSampler")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetShapeFn(CandidateSamplerShapeFn)
+ .SetIsStateful()
.Doc(R"doc(
Generates labels for candidate sampling with a learned unigram distribution.
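Why SetIsStateful() is the right fix: IsConstantFoldable() rejects any node whose op_def is stateful, so the samplers' pseudorandom outputs are no longer evaluated at graph-optimization time. A minimal sketch of the pattern for a hypothetical op (ExampleRandomSampler is not a real TensorFlow op):

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

REGISTER_OP("ExampleRandomSampler")  // hypothetical op name
    .Output("samples: int64")
    .Attr("seed: int = 0")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful()  // random ops must never be constant-folded or CSE'd
    .Doc(R"doc(
Draws pseudorandom samples; marked stateful so the optimizer leaves it alone.
)doc");

}  // namespace tensorflow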