diff options
author | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
---|---|---|
committer | Manjunath Kudlur <keveman@gmail.com> | 2015-11-06 16:27:58 -0800 |
commit | f41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch) | |
tree | ef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/graph/optimizer_cse_test.cc |
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation
using data flow graphs.
Base CL: 107276108
Diffstat (limited to 'tensorflow/core/graph/optimizer_cse_test.cc')
-rw-r--r-- | tensorflow/core/graph/optimizer_cse_test.cc | 365 |
1 files changed, 365 insertions, 0 deletions
diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc new file mode 100644 index 0000000000..ebbb948fdc --- /dev/null +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -0,0 +1,365 @@ +#include "tensorflow/core/graph/optimizer_cse.h" + +#include <gtest/gtest.h> +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/tensor.h" + +namespace tensorflow { +namespace { + +static void InitGraph(const string& s, Graph* graph) { + GraphDef graph_def; + + auto parser = protobuf::TextFormat::Parser(); + // parser.AllowRelaxedWhitespace(true); + CHECK(parser.MergeFromString(s, &graph_def)) << s; + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); +} + +class OptimizerCSETest : public ::testing::Test { + public: + OptimizerCSETest() : graph_(OpRegistry::Global()) { RequireDefaultOps(); } + + void InitGraph(const string& s) { + ::tensorflow::InitGraph(s, &graph_); + original_ = CanonicalGraphString(&graph_); + } + + static bool IncludeNode(const Node* n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->name(); + } else if (index == Graph::kControlSlot) { + return strings::StrCat(n->name(), ":control"); + } else { + return strings::StrCat(n->name(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + std::vector<string> nodes; + std::vector<string> edges; + for (const Node* n : g->nodes()) { + if (IncludeNode(n)) { + nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); + } + } + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + std::sort(nodes.begin(), nodes.end()); + std::sort(edges.begin(), edges.end()); + return strings::StrCat(str_util::Join(nodes, ";"), "|", + str_util::Join(edges, ";")); + } + + string DoCSE(std::function<bool(const Node*)> consider_fn = nullptr) { + string before = CanonicalGraphString(&graph_); + LOG(ERROR) << "Before rewrites: " << before; + + OptimizeCSE(&graph_, consider_fn); + + string result = CanonicalGraphString(&graph_); + LOG(ERROR) << "After rewrites: " << result; + return result; + } + + const string& OriginalGraph() const { return original_; } + + Graph graph_; + string original_; +}; + +REGISTER_OP("Input").Output("o: float").SetIsStateful(); + +// Note that the "rules" in these tests are not meant to be logically correct +TEST_F(OptimizerCSETest, Simple) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, Simple_ThreeEquivalent) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);E(Mul)|" + "A->E;B->E:1"); +} + +TEST_F(OptimizerCSETest, Simple_WithFixups) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul);E(Mul)|" + "A->D;B->D:1;D->E;D->E:1"); +} + +TEST_F(OptimizerCSETest, Simple_Commutative) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D:1;B->D"); +} + +static bool IsNotMultiply(const Node* n) { return n->type_string() != "Mul"; } + +// Like Simple_Commutative, +TEST_F(OptimizerCSETest, Simple_Filtered) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] }"); + EXPECT_EQ(DoCSE(IsNotMultiply), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Simple_NotCommutative) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'A'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, NotEquivalent_Ops) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'D' op: 'Sub' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs1) { + // Should still do CSE for ops with attrs if they match. + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] attr { key: 'shape'" + " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] attr { key: 'shape'" + " value { shape: { dim: { size: 37 name: 'SAME_NAME' } } } } }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, Simple_SameOps_SameAttrs2) { + // Should still do CSE for ops with attrs if they match, even if they + // are not in the same order. + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_INT32 } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 3 } } }"); + EXPECT_EQ(DoCSE(), + "A(Input);B(Input);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, SameConstants) { + // Should still do CSE for ops with constants if the values are identical + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "B(Const);D(Mul)|" + "B->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, DifferentConstants) { + // Should still do CSE for ops with extensions if the extensions are identical + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT32 } }" + " attr { key: 'value' value {" + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 100000 } } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_INT32 } }" + " input: ['A', 'B'] }"); + EXPECT_EQ(DoCSE(), + "A(Const);B(Const);D(Mul)|" + "A->D;B->D:1"); +} + +TEST_F(OptimizerCSETest, SameOps_DifferentAttrs1) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_INT32 } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 4 } } }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, SameOps_DifferentAttrs2) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 'a' value { i: 3 } }" + " attr { key: 't' value { type: DT_FLOAT } } }" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B']" + " attr { key: 't' value { type: DT_INT32 } }" + " attr { key: 'a' value { i: 3 } } }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, NotEquivalent_Inputs) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'B'] }" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'C'] }"); + EXPECT_EQ(DoCSE(), OriginalGraph()); +} + +TEST_F(OptimizerCSETest, Constant_Dedup) { + Tensor a(DT_FLOAT, TensorShape({1})); + a.flat<float>()(0) = 1.0; + Tensor b(DT_DOUBLE, TensorShape({1})); // Different type + b.flat<double>()(0) = 1.0; + Tensor c(DT_FLOAT, TensorShape({1, 1})); // Different shape + c.flat<float>()(0) = 1.0; + Tensor d(DT_FLOAT, TensorShape({1})); // Different value + d.flat<float>()(0) = 2.0; + + // A graph contains a bunch of constants. + Graph g(OpRegistry::Global()); + for (auto val : {a, b, c, d, d, c, b, a}) { + test::graph::Constant(&g, val); // Node name is n/_0, n/_1, ... + } + GraphDef gdef; + test::graph::ToGraphDef(&g, &gdef); + InitGraph(gdef.DebugString()); + + EXPECT_EQ(OriginalGraph(), + "n/_0(Const);n/_1(Const);n/_2(Const);n/_3(Const);" + "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); + // In theory, there are 2^4 possible correct output of CSE. In this + // test, it happens happens to eliminate the first 4 nodes. + EXPECT_EQ(DoCSE(), "n/_4(Const);n/_5(Const);n/_6(Const);n/_7(Const)|"); +} + +static void BM_CSE(int iters, int op_nodes) { + testing::StopTiming(); + string s; + for (int in = 0; in < 10; in++) { + s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); + } + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < op_nodes; op++) { + s += strings::Printf( + "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { " + "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", + op, rnd.Uniform(10), rnd.Uniform(10)); + } + + bool first = true; + while (iters > 0) { + Graph* graph = new Graph(OpRegistry::Global()); + InitGraph(s, graph); + int N = graph->num_node_ids(); + if (first) { + testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N)); + first = false; + } + { + testing::StartTiming(); + OptimizeCSE(graph, nullptr); + testing::StopTiming(); + } + iters -= N; // Our benchmark units are individual graph nodes, + // not whole graphs + delete graph; + } +} +BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000); + +} // namespace +} // namespace tensorflow |