diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass_test.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 199 |
1 files changed, 199 insertions, 0 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc new file mode 100644 index 0000000000..10671ee2e9 --- /dev/null +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#include "tensorflow/core/graph/mkl_layout_pass.h" +#include "tensorflow/core/util/mkl_util.h" + +#include <vector> +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace { + +static void InitGraph(const string& s, Graph* graph) { + GraphDef graph_def; + + auto parser = protobuf::TextFormat::Parser(); + // parser.AllowRelaxedWhitespace(true); + CHECK(parser.MergeFromString(s, &graph_def)) << s; + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); +} + +class MklLayoutPassTest : public ::testing::Test { + public: + MklLayoutPassTest() : graph_(OpRegistry::Global()) {} + + void InitGraph(const string& s) { + ::tensorflow::InitGraph(s, &graph_); + original_ = CanonicalGraphString(&graph_); + } + + static bool IncludeNode(const Node* n) { return n->IsOp(); } + + static string EdgeId(const Node* n, int index) { + if (index == 0) { + return n->name(); + } else if (index == Graph::kControlSlot) { + return strings::StrCat(n->name(), ":control"); + } else { + return strings::StrCat(n->name(), ":", index); + } + } + + string CanonicalGraphString(Graph* g) { + std::vector<string> nodes; + std::vector<string> edges; + for (const Node* n : g->nodes()) { + if (IncludeNode(n)) { + nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); + } + } + for (const Edge* e : g->edges()) { + if (IncludeNode(e->src()) && IncludeNode(e->dst())) { + edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", + EdgeId(e->dst(), e->dst_input()))); + } + } + // Canonicalize + std::sort(nodes.begin(), nodes.end()); + std::sort(edges.begin(), edges.end()); + return strings::StrCat(str_util::Join(nodes, ";"), "|", + str_util::Join(edges, ";")); + } + + string DoMklLayoutOptimizationPass() { + string before = CanonicalGraphString(&graph_); + LOG(ERROR) << "Before MKL layout rewrite pass: " << before; + + std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_); + RunMklLayoutRewritePass(ug); + + string result = CanonicalGraphString(&graph_); + LOG(ERROR) << "After MKL layout rewrite pass: " << result; + return result; + } + + const string& OriginalGraph() const { return original_; } + + Graph graph_; + string original_; +}; + +REGISTER_OP("Input").Output("o: float").SetIsStateful(); + +// Single Conv2D Op; No Mkl layer on the input and on the output. +// We will generate dummy Mkl tensor as 2nd input of Conv2D. +TEST_F(MklLayoutPassTest, Conv2D_Basic) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Conv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B']}" + "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['B', 'C'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|" + "A->C;B->C:2;B->D;C->D:1;DMT/_0->C:1;DMT/_1->C:3"); +} + +// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will +// have 2 outputs, both of which will be inputs to next Conv2D. +TEST_F(MklLayoutPassTest, Conv2D_Positive1) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Conv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B']}" + "node { name: 'D' op: 'Conv2D'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'C']}" + "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:2;C->D:2;C->E;" + "C:1->D:3;D->E:1;DMT/_0->C:1;DMT/_1->C:3;DMT/_2->D:1"); +} + +static void BM_MklLayoutRewritePass(int iters, int op_nodes) { + testing::StopTiming(); + string s; + for (int in = 0; in < 10; in++) { + s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); + } + random::PhiloxRandom philox(301, 17); + random::SimplePhilox rnd(&philox); + for (int op = 0; op < op_nodes; op++) { + s += strings::Printf( + "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { " + "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", + op, rnd.Uniform(10), rnd.Uniform(10)); + } + + bool first = true; + while (iters > 0) { + Graph* graph = new Graph(OpRegistry::Global()); + InitGraph(s, graph); + int N = graph->num_node_ids(); + if (first) { + testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N)); + first = false; + } + { + testing::StartTiming(); + std::unique_ptr<Graph> ug(graph); + RunMklLayoutRewritePass(&ug); + testing::StopTiming(); + } + iters -= N; // Our benchmark units are individual graph nodes, + // not whole graphs + // delete graph; + } +} +BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); + +} // namespace +} // namespace tensorflow + +#endif /* INTEL_MKL */ |