diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass.cc | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 590b3d030f..3f8b0e86d0 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -64,6 +64,15 @@ namespace tensorflow { // in the Mkl format. Non-compliant ops accept inputs and outputs in the // TensorFlow format. // +// ADDENDUM: For element-wise ops, we may or may not need a conversion to +// take place before we hit the op. For this, we add a new op before each +// element-wise MKL op to deal with the inputs, called _MklInputConversion. +// This pass has been enhanced to add this capability. +// +// The _MklInputConversion op will check the inputs to the elementwise op and +// make sure that either both are in MKL format or both are in TF format, +// depending on their initial state and whether broadcast is needed or not. + class MklToTfConversionPass : public GraphOptimizationPass { public: MklToTfConversionPass() {} @@ -87,6 +96,16 @@ class MklToTfConversionPass : public GraphOptimizationPass { return mkl_op_registry::IsMklOp(op_name, T); } + // Is the input Op supported by Mkl-specific layout AND + // is it element-wise? + // + // @input op_name string of the op + // @input T Datatype to use for checking input op + // @return true if op is Mkl supported; false, otherwise. + inline bool IsMklElementWiseOp(const string& op_name, DataType T) const { + return mkl_op_registry::IsMklElementWiseOp(op_name, T); + } + // Insert layout conversion node on the edge pointed by 'e' from graph 'g'. // // Edge will be deleted once a call to this function is successful. @@ -96,6 +115,17 @@ class MklToTfConversionPass : public GraphOptimizationPass { // @return Success:OK() if insertion is successful, otherwise returns // appropriate error status code. Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*); + + // For element-wise ops, we need to sanitize the inputs. For this, we add a + // new node at the input of the replacement element-wise node that checks + // the inputs and converts one/both of them as required. See the op code + // comments for details. + // + // Insert input conversion node as parent of 'n' from graph 'g'. + // + // @return Success:OK() if insertion is successful, otherwise returns + // appropriate error status code. + Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*); }; // We register MklToTf insertion for phase 2 in post-partition grouping @@ -171,6 +201,92 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( return Status::OK(); } +Status MklToTfConversionPass::InsertInputConversionNode( + std::unique_ptr<Graph>* g, Node* n) { + CHECK_NOTNULL(n); + + // Get the input nodes and edges + std::vector<const Edge*> edges; + TF_CHECK_OK(n->input_edges(&edges)); + if (edges.size() != 4) { + return Status(error::Code::INVALID_ARGUMENT, + "MKL Binary Element-wise op should have exactly 2 data" + " inputs and 2 metadata inputs"); + } + + // Sanity check: ensure that both inputs are of the expected type, and the + // same type as input type + CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), + BaseType(edges[1]->src()->output_type(edges[1]->src_output()))); + CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), + BaseType(n->input_type(0))); + + // Check ordering of edges + for (uint i = 0; i < 4; i++) { + CHECK_EQ((edges[i]->dst_input() == i), true); + } + + // Build the conversion node and specify src as input. + Node* conversion_node = nullptr; + + TF_CHECK_OK( + NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion") + .Input(edges[0]->src(), edges[0]->src_output()) + .Input(edges[1]->src(), edges[1]->src_output()) + .Input(edges[2]->src(), edges[2]->src_output()) + .Input(edges[3]->src(), edges[3]->src_output()) + .Device(n->def().device()) + .Attr("T", n->input_type(0)) + .Finalize(&**g, &conversion_node)); + + CHECK_NOTNULL(conversion_node); + + // Change the destination of any control edges to the InputConversion node + if (edges.size() != n->in_edges().size()) { + std::vector<const Edge*> edges_to_remove; + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node)); + edges_to_remove.push_back(e); + } + } + for (const Edge* e : edges_to_remove) { + (*g)->RemoveEdge(e); + } + } + + string data_format; + if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) == + Status::OK()) { + conversion_node->AddAttr("data_format", data_format); + } + + // Get assigned device from destination node and apply it to conversion node. + // We want conversion node to be on the same device as the destination node. + conversion_node->set_assigned_device_name(n->assigned_device_name()); + + // Set the Mkl op label for this op. + conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); + + // Now that we have added edges from src->conversion_node, let's add edge from + // output of conversion_node to the element-wise node. + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input())); + CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input())); + + VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input " + << "conversion node on: " << n->type_string() << " successful."; + + // Remove src->dst edge now. + (*g)->RemoveEdge(edges[0]); + (*g)->RemoveEdge(edges[1]); + (*g)->RemoveEdge(edges[2]); + (*g)->RemoveEdge(edges[3]); + + return Status::OK(); +} + bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { bool result = false; @@ -239,6 +355,49 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { DumpGraph("After MklToTfConversionPass", &**g); + //--------------------------------------------------------------------------- + // Check all nodes and add an input-conversion-node if the node is an mkl + // element-wise node. + VLOG(1) << "Before running MklToTfConversionPass - InputConversion"; + + std::vector<Node*> candidate_nodes; + std::vector<Node*> order; + GetReversePostOrder(**g, &order); // This will give us topological sort. + + for (Node* n : order) { + // If node is not an op or it does not have a datatype, then skip. + DataType datatype; + if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) { + continue; + } + if (IsMklElementWiseOp(n->type_string(), datatype)) { + // If the input node is an input-conversion op, skip + Node* input_node = nullptr; + TF_CHECK_OK(n->input_node(0, &input_node)); + DataType input_datatype; + if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) && + (input_node->type_string().compare("_MklInputConversion") == 0)) { + continue; + } + + VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node " + << n->name() << " for inserting input conversion node"; + candidate_nodes.push_back(const_cast<Node*>(n)); + } + } + + // Process all candidate edges and insert conversion nodes on them. + for (Node* n : candidate_nodes) { + // Even if we insert conversion node on a single node, we + // need to return true. + if (InsertInputConversionNode(g, n) == Status::OK()) { + VLOG(1) << "MklToTfConversionPass: Inserted conversion " + << "on node " << n->name(); + result = true; + } + } + DumpGraph("After MklToTfConversionPass - InputConversion", &**g); + // We need to return true even if we insert one conversion node // anywhere in the graph. return result; |