aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/optimizer_cse.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/optimizer_cse.cc')
-rw-r--r--tensorflow/core/graph/optimizer_cse.cc220
1 files changed, 220 insertions, 0 deletions
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
new file mode 100644
index 0000000000..2fa6f075c0
--- /dev/null
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -0,0 +1,220 @@
+// This module implements a common subexpression elimination pass. We
+// process the nodes in the graph in reverse postorder
+// (i.e. inputs before their downstream dependencies). The rough algorithm is
+// as follows:
+//
+// std::unordered_map<size_t, Node*> available
+// for each node n in forward topological order:
+// h = NodeHash(n)
+// if available[h] exists and Equivalent(available[h], n)
+// redirect downstream uses of outputs of n to available[h]
+// remove n from graph
+// else
+// if available[h] does not exist
+// available[h] = n
+//
+// This is similar to the global value numbering algorithm described in this
+// paper:
+//
+// "Global code motion/global value numbering", Cliff Click, PLDI '95
+// Proceedings of the ACM SIGPLAN 1995 conference on Programming
+// language design and implementation, Pages 246-257
+// http://dl.acm.org/citation.cfm?id=207154
+
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Performs one pass of common subexpression elimination over the graph
+// handed to the constructor. See the file-level comment above for the
+// algorithm.
+class OptimizerCSE {
+ public:
+  explicit OptimizerCSE(Graph* g) : g_(g) {}
+
+  // Runs the CSE pass over g_. If consider_fn is non-null, nodes for
+  // which it returns false are left untouched.
+  void Optimize(std::function<bool(const Node*)> consider_fn);
+
+ private:
+  struct Scratch;
+
+  // Hash of n's op type, output types, inputs and (on non-Android builds)
+  // attrs. Never returns kIllegalNodeHash.
+  static size_t NodeHash(const Node* n);
+  // True iff a and b may safely be merged into a single node. s is
+  // reusable scratch space for attr comparison.
+  static bool Equivalent(const Node* a, const Node* b, Scratch* s);
+  // True iff a and b carry identical attr maps (compared by serialized
+  // value).
+  static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
+
+  Graph* g_;  // Not owned.
+};
+
+// Fills *in with the (source node, source port) of each data input of n,
+// indexed by input slot, and *control_edges with the sources of n's control
+// edges. Both are put into a canonical order so that two nodes with the
+// same input sources compare equal via operator== on the vectors.
+static void FillInputs(const Node* n,
+                       gtl::InlinedVector<Node*, 4>* control_edges,
+                       gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
+  DCHECK_EQ(in->size(), n->num_inputs());
+  control_edges->clear();
+  for (const Edge* e : n->in_edges()) {
+    if (e->IsControlEdge()) {
+      control_edges->push_back(e->src());
+    } else {
+      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
+    }
+  }
+  // Sort by node id rather than by Node* address: pointer order can vary
+  // from run to run, which would make the canonical ordering (and hence
+  // NodeHash, which hashes node ids) nondeterministic across processes.
+  std::sort(control_edges->begin(), control_edges->end(),
+            [](const Node* a, const Node* b) { return a->id() < b->id(); });
+  if (n->op_def().is_commutative()) {
+    // For commutative inputs, we sort the inputs into a canonical ordering
+    // (so that add(a,b) and add(b, a) will hash to the same value if
+    // is_commutative is true for 'add').
+    std::sort(in->begin(), in->end(),
+              [](const std::pair<Node*, int>& x,
+                 const std::pair<Node*, int>& y) {
+                if (x.first->id() != y.first->id()) {
+                  return x.first->id() < y.first->id();
+                }
+                return x.second < y.second;
+              });
+  }
+}
+
+// Sentinel hash value no real node is allowed to produce.
+static size_t kIllegalNodeHash = 0;
+
+size_t OptimizerCSE::NodeHash(const Node* n) {
+  // Build a string encoding the op type, the output types, and the
+  // (node id, output port) of every data input, then hash it.
+  const DataTypeVector& out_types = n->output_types();
+  string hash_input = strings::StrCat(n->type_string(), out_types.size());
+  for (DataType dt : out_types) {
+    strings::StrAppend(&hash_input, dt);
+  }
+
+  const int num_in = n->num_inputs();
+  strings::StrAppend(&hash_input, num_in);
+  gtl::InlinedVector<Node*, 4> control_edges;
+  gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_in);
+  FillInputs(n, &control_edges, &inputs);
+  for (const auto& input : inputs) {
+    strings::StrAppend(&hash_input, input.first->id(), input.second);
+  }
+
+  size_t h = Hash64(hash_input);
+
+#if !defined(__ANDROID__) && !defined(ANDROID)
+  // Mix in a hash of each attr (name plus serialized value). For example,
+  // this makes sure different constants end up in different hash buckets.
+  // Summing the per-attr hashes keeps the result independent of the attr
+  // iteration order.
+  string attr_buf;
+  for (const auto& attr : n->def().attr()) {
+    attr_buf = attr.first;
+    attr.second.AppendToString(&attr_buf);
+    h += Hash32(attr_buf.data(), attr_buf.size(), 0x87341245);
+  }
+#endif
+
+  // Remap the sentinel so callers can rely on never seeing it.
+  if (h == kIllegalNodeHash) h = kIllegalNodeHash + 1;
+  return h;
+}
+
+// Reusable serialization buffers, allocated once per Optimize() pass and
+// threaded through Equivalent()/EqualAttrs() to avoid per-comparison
+// string allocations.
+struct OptimizerCSE::Scratch {
+  // For EqualAttrs():
+  string a;
+  string b;
+};
+
+// Returns true iff a and b have exactly the same attrs: same names, and
+// byte-identical serialized values for each name.
+bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
+  if (a->def().attr_size() != b->def().attr_size()) return false;
+
+  for (const auto& attr : a->def().attr()) {
+    auto it = b->def().attr().find(attr.first);
+    if (it == b->def().attr().end()) return false;
+    // Note: it should be safe to compare proto serializations of the attr
+    // values since at most one field should be set in each (indeed, it
+    // should be the same field).
+    attr.second.SerializeToString(&scratch->a);
+    it->second.SerializeToString(&scratch->b);
+    if (scratch->a != scratch->b) return false;
+  }
+  return true;
+}
+
+// Returns true if any of n's input types is a ref type.
+static bool HasRefInput(const Node* n) {
+  const DataTypeVector& in_types = n->input_types();
+  for (size_t i = 0; i < in_types.size(); ++i) {
+    if (IsRefType(in_types[i])) return true;
+  }
+  return false;
+}
+
+bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
+  // Nodes running different ops are never equivalent.
+  if (a->type_string() != b->type_string()) return false;
+
+  // Never consider stateful nodes (such as non-const inputs) equivalent.
+  if (a->op_def().is_stateful()) return false;
+
+  // For now, conservatively treat any node that takes a ref input as
+  // equivalent to nothing else.
+  if (HasRefInput(a) || HasRefInput(b)) return false;
+
+  // Compare attrs. Note that equal attrs implies equal input and
+  // output types.
+  if (!EqualAttrs(a, b, scratch)) return false;
+
+  // Finally, both nodes must draw their data and control inputs from the
+  // same sources (compared in canonical order; see FillInputs).
+  if (a->num_inputs() != b->num_inputs()) return false;
+  const int num_in = a->num_inputs();
+  gtl::InlinedVector<Node*, 4> a_ctrl;
+  gtl::InlinedVector<Node*, 4> b_ctrl;
+  gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(num_in);
+  gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(num_in);
+  FillInputs(a, &a_ctrl, &a_in);
+  FillInputs(b, &b_ctrl, &b_in);
+  return a_in == b_in && a_ctrl == b_ctrl;
+}
+
+void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
+  // This very simple implementation works if the whole graph is one
+  // giant basic block (because we just traverse nodes in a
+  // topological order). We'll need to do something more
+  // sophisticated when we have control flow/loops/etc.
+
+  // TODO(jeff): We need to handle Update nodes specially, but dealing
+  // with more general control flow will also solve this issue, and for
+  // now, our updates are almost always the most downstream nodes in
+  // the graph.
+  std::vector<Node*> order;
+  GetReversePostOrder(*g_, &order);
+
+  // Map from node hash to a single representative node. Keeping one
+  // candidate per hash (rather than a set of candidates) may — rarely —
+  // miss an optimization opportunity on a hash collision, but keeps the
+  // table small and the lookup cheap.
+  std::unordered_map<size_t, Node*> available;
+
+  // Scratch space shared across Equivalent() calls so the loop below
+  // does not allocate per comparison.
+  Scratch scratch;
+  for (Node* node : order) {
+    if (!node->IsOp()) continue;
+
+    // Skip nodes the caller asked us not to consider.
+    if (consider_fn != nullptr && !consider_fn(node)) continue;
+
+    // Note: operator[] default-inserts nullptr when the hash is new.
+    Node** slot = &available[NodeHash(node)];
+    if (*slot == nullptr) {
+      // First node seen with this hash: record it as the representative.
+      *slot = node;
+    } else if (Equivalent(*slot, node, &scratch)) {
+      VLOG(1) << "CSE: equivalent: " << (*slot)->name() << " and "
+              << node->name();
+      // node duplicates *slot. Redirect every outgoing edge of node to
+      // originate from *slot instead, then remove node from the graph.
+      for (const Edge* e : node->out_edges()) {
+        g_->AddEdge(*slot, e->src_output(), e->dst(), e->dst_input());
+      }
+      g_->RemoveNode(node);
+    }
+  }
+}
+
+// Public entry point: runs a single CSE pass over g, consulting
+// consider_fn (when non-null) to decide which nodes may participate.
+void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
+  OptimizerCSE optimizer(g);
+  optimizer.Optimize(consider_fn);
+}
+
+} // namespace tensorflow