Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r-- tensorflow/core/graph/mkl_tfconversion_pass.cc | 271
1 file changed, 271 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
new file mode 100644
index 0000000000..1e7b5e7094
--- /dev/null
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -0,0 +1,271 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <set>
+#include <vector>
+#include <queue>
+#include <utility>
+#include <string>
+#include <memory>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+
+// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
+// in the graph in between A and B, where A and B match any one of the
+// following cases:
+//
+// 1) A = layer/Op that generates output in Mkl format,
+//    B = layer/Op that does not accept input in Mkl format, and
+//    A -> B (there is a direct edge between A and B); then
+//    we will insert C such that A->C->B.
+//
+// 2) A = layer/Op that generates output in Mkl format, and
+//    B = NULL (in other words, A is the last layer in the graph); then
+//    we will insert C such that A->C. (C will be the last layer.)
+//
+// Note that case 1 applies to all outputs of A that are inputs to B.
+// In other words, a conversion is required for every output
+// of A that is an input to B. For example, let us say the outputs of A
+// are A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not
+// in Mkl format, and all of them are inputs to B. In such a case, we will
+// do the conversion for A1 and A2 only; we do not need to do any conversion
+// for A3.
+//
+// This pass relies on layers registering themselves as Mkl compliant.
+// An Mkl-compliant layer can accept inputs in Mkl format and produce outputs
+// in Mkl format. A non-compliant layer accepts inputs and produces outputs in
+// TensorFlow format.
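+//
+// For example (the op names here are purely illustrative), if an Mkl-compliant
+// convolution produces its output in Mkl format and feeds a plain Relu that
+// only understands TensorFlow format, this pass rewrites the edge Conv->Relu
+// into Conv->MklToTf->Relu.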
+//
+class MklToTfConversionPass : public GraphOptimizationPass {
+ public:
+ MklToTfConversionPass() {}
+ Status Run(const GraphOptimizationPassOptions& options);
+
+ // Inserts layout conversion nodes in the graph pointed to by g.
+ // The function scans the graph for candidate edges where we
+ // need to insert conversion nodes.
+ //
+ // @return true if at least one conversion node is inserted;
+ // false otherwise.
+ bool RunPass(std::unique_ptr<Graph>* g);
+
+ private:
+ // Is the input Op supported by the Mkl-specific layout?
+ //
+ // @input op_name name of the op
+ // @return true if the op is Mkl supported; false otherwise.
+ inline bool IsMklSupportedOp(const string& op_name) const {
+ return mkl_layer_registry::IsMklLayer(op_name);
+ }
+
+ // Insert a layout conversion node on the edge pointed to by 'e' in graph 'g'.
+ //
+ // The edge will be deleted once a call to this function is successful.
+ // Any attempt to use the edge after this call
+ // will lead to undefined behavior.
+ //
+ // @return Status::OK() if insertion is successful, otherwise returns an
+ // appropriate error status code.
+ Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge* e);
+};
+
+// We register MklToTf insertion for phase 1 in post-partition grouping.
+// We register this pass after partitioning so that we get a complete
+// picture of inputs and outputs of the nodes in the graphs.
+const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
+ OptimizationPassRegistry::POST_PARTITIONING;
+REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass);
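+// Because the pass runs in the post-partitioning grouping, Run() receives the
+// already-partitioned subgraphs in options.partition_graphs and applies
+// RunPass to each of them (see Run() below).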
+
+Status MklToTfConversionPass::InsertConversionNodeOnEdge(
+ std::unique_ptr<Graph>* g, Edge *e) {
+ CHECK_NOTNULL(e);
+
+ Node* src = e->src();
+ Node* dst = e->dst();
+
+ CHECK_NOTNULL(src);
+ CHECK_NOTNULL(dst);
+
+ Node* conversion_node = nullptr;
+ DataType src_datatype = DT_INVALID;
+ DataType dst_datatype = DT_INVALID;
+ string data_format;
+
+ TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
+ TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
+ if (src_datatype != dst_datatype) {
+ string err_msg = "T attribute of " + src->name() + " and " +
+ dst->name() + " do not match. Will not insert" +
+ " MklToTf node in such case.";
+ return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
+ }
+
+ // Let's build the conversion node and specify src as its input.
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
+ .Input(src, e->src_output())
+ .Input(src, e->src_output()+1) // Mkl tensor immediately
+ // follows Tf tensor.
+ .Device(src->def().device()) // We want to get conversion node
+ // on same device as source node.
+ .Attr("T", src_datatype)
+ .Finalize(&**g, &conversion_node));
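+ // The conversion node thus consumes both the Tf tensor and its companion
+ // Mkl metadata tensor, and (as noted below) produces a single output tensor
+ // in TensorFlow format.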
+
+ CHECK_NOTNULL(conversion_node);
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ conversion_node->AddAttr("data_format", data_format);
+ }
+
+ // Get assigned device from source node and apply it to conversion node.
+ // We want conversion node to be on the same device as the source node.
+ conversion_node->set_assigned_device_name(src->assigned_device_name());
+
+ // Set the Mkl layer label for this op.
+ conversion_node->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
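+ // (The "_kernel" attribute is the kernel label used during kernel selection,
+ // so the Mkl-labeled registration of the MklToTf kernel gets picked.)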
+
+ // Now that we have added an edge from src to conversion_node, let's add an
+ // edge from the output of conversion_node to the dest node. Since
+ // conversion_node has only 1 output, the src_output of conversion_node is 0.
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input()));
+
+ VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: "
+ << src->type_string() << " and " << dst->type_string()
+ << " successful.";
+
+ // Remove src->dst edge now.
+ (*g)->RemoveEdge(e);
+ return Status::OK();
+}
+
+bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
+ bool result = false;
+
+ CHECK_NOTNULL(g);
+
+ DumpGraph("Before MklToTfConversionPass", &**g);
+
+ // Since we are looking for an Mkl-supported op node immediately
+ // followed by a non-Mkl op node, we just iterate over the edge
+ // set of the graph.
+ // Candidate edges (whose source and destination qualify for inserting a
+ // conversion node) are collected first and processed afterwards, so that we
+ // do not mutate the edge set while we are iterating over it.
+ std::vector<Edge*> candidate_edges;
+
+ for (const Edge *e : (*g)->edges()) {
+ Node* src = e->src();
+ Node* dst = e->dst();
+
+ // We skip control edges.
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
+ << src->type_string() << " and " << dst->type_string();
+
+ // Let's get source and destination data type.
+ DataType src_datatype = DT_INVALID;
+ if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) {
+ continue;
+ }
+ // We cannot check the datatype on the destination node because the
+ // destination node may not be an Mkl node and may not carry a "T" attribute.
+ DataType dst_datatype = DT_INVALID;
+ GetNodeAttr(dst->def(), "T", &dst_datatype);
+
+ // Check if src is Mkl-compliant, while dst is not Mkl-compliant.
+ if (IsMklSupportedOp(src->type_string()) &&
+ !IsMklSupportedOp(dst->type_string())) {
+ VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
+ << " and " << dst->name() << " for inserting conversion nodes";
+ candidate_edges.push_back(const_cast<Edge*>(e));
+ }
+ }
+
+ // Process all candidate edges and insert conversion nodes on them.
+ for (Edge* e : candidate_edges) {
+ // Even if we insert a conversion node on just a single edge, we
+ // need to return true.
+ string src_name = e->src()->name();
+ string dst_name = e->dst()->name();
+ if (InsertConversionNodeOnEdge(g, e) == Status::OK()) {
+ VLOG(1) << "MklToTfConversionPass: Inserted conversion "
+ << "node on edge between " << src_name << " and " << dst_name;
+ result = true;
+ }
+ }
+
+ DumpGraph("After MklToTfConversionPass", &**g);
+
+ // We need to return true if we have inserted at least one conversion node
+ // anywhere in the graph.
+ return result;
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// Run function for the pass
+//////////////////////////////////////////////////////////////////////////////
+
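+// Standalone helper that applies the pass to a single graph. Presumably this
+// allows the conversion logic to be invoked directly (for example, from
+// tests) without going through the optimization pass registry.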
+bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
+ return MklToTfConversionPass().RunPass(g);
+}
+
+Status MklToTfConversionPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ if (options.graph == nullptr && options.partition_graphs == nullptr) {
+ return Status::OK();
+ }
+
+ auto process_graph = [&](std::unique_ptr<Graph>* g) {
+ // Get the ownership of graph
+ std::unique_ptr<Graph>* ng = std::move(g);
+ RunPass(ng);
+ // Return the ownership of graph back
+ g->reset(ng->release());
+ };
+
+ if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) {
+ // For any pre-partitioning phase, the graph is stored in options.graph.
+ process_graph(options.graph);
+ } else {
+ // For post partitioning phase, graphs are stored in
+ // options.partition_graphs.
+ for (auto& pg : *options.partition_graphs) {
+ process_graph(&pg.second);
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
+
+#endif  // INTEL_MKL