aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_tfconversion_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc159
1 files changed, 159 insertions, 0 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 590b3d030f..3f8b0e86d0 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -64,6 +64,15 @@ namespace tensorflow {
// in the Mkl format. Non-compliant ops accept inputs and outputs in the
// TensorFlow format.
//
+// ADDENDUM: For element-wise ops, we may or may not need a conversion to
+// take place before we hit the op. For this, we add a new op before each
+// element-wise MKL op to deal with the inputs, called _MklInputConversion.
+// This pass has been enhanced to add this capability.
+//
+// The _MklInputConversion op will check the inputs to the elementwise op and
+// make sure that either both are in MKL format or both are in TF format,
+// depending on their initial state and whether broadcast is needed or not.
+
class MklToTfConversionPass : public GraphOptimizationPass {
public:
MklToTfConversionPass() {}
@@ -87,6 +96,16 @@ class MklToTfConversionPass : public GraphOptimizationPass {
return mkl_op_registry::IsMklOp(op_name, T);
}
+ // Is the input Op supported by Mkl-specific layout AND
+ // is it element-wise?
+ //
+ // @input op_name string of the op
+ // @input T Datatype to use for checking input op
+ // @return true if op is Mkl supported; false, otherwise.
+ inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
+ return mkl_op_registry::IsMklElementWiseOp(op_name, T);
+ }
+
// Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
//
// Edge will be deleted once a call to this function is successful.
@@ -96,6 +115,17 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// @return Success:OK() if insertion is successful, otherwise returns
// appropriate error status code.
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
+
+ // For element-wise ops, we need to sanitize the inputs. For this, we add a
+ // new node at the input of the replacement element-wise node that checks
+ // the inputs and converts one/both of them as required. See the op code
+ // comments for details.
+ //
+ // Insert input conversion node as parent of 'n' from graph 'g'.
+ //
+ // @return Success:OK() if insertion is successful, otherwise returns
+ // appropriate error status code.
+ Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};
// We register MklToTf insertion for phase 2 in post-partition grouping
@@ -171,6 +201,92 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
return Status::OK();
}
+Status MklToTfConversionPass::InsertInputConversionNode(
+ std::unique_ptr<Graph>* g, Node* n) {
+ CHECK_NOTNULL(n);
+
+ // Get the input nodes and edges
+ std::vector<const Edge*> edges;
+ TF_CHECK_OK(n->input_edges(&edges));
+ if (edges.size() != 4) {
+ return Status(error::Code::INVALID_ARGUMENT,
+ "MKL Binary Element-wise op should have exactly 2 data"
+ " inputs and 2 metadata inputs");
+ }
+
+ // Sanity check: ensure that both inputs are of the expected type, and the
+ // same type as input type
+ CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
+ BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
+ CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
+ BaseType(n->input_type(0)));
+
+ // Check ordering of edges
+ for (uint i = 0; i < 4; i++) {
+ CHECK_EQ((edges[i]->dst_input() == i), true);
+ }
+
+ // Build the conversion node and specify src as input.
+ Node* conversion_node = nullptr;
+
+ TF_CHECK_OK(
+ NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
+ .Input(edges[0]->src(), edges[0]->src_output())
+ .Input(edges[1]->src(), edges[1]->src_output())
+ .Input(edges[2]->src(), edges[2]->src_output())
+ .Input(edges[3]->src(), edges[3]->src_output())
+ .Device(n->def().device())
+ .Attr("T", n->input_type(0))
+ .Finalize(&**g, &conversion_node));
+
+ CHECK_NOTNULL(conversion_node);
+
+ // Change the destination of any control edges to the InputConversion node
+ if (edges.size() != n->in_edges().size()) {
+ std::vector<const Edge*> edges_to_remove;
+ for (const Edge* e : n->in_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (const Edge* e : edges_to_remove) {
+ (*g)->RemoveEdge(e);
+ }
+ }
+
+ string data_format;
+ if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
+ Status::OK()) {
+ conversion_node->AddAttr("data_format", data_format);
+ }
+
+ // Get assigned device from destination node and apply it to conversion node.
+ // We want conversion node to be on the same device as the destination node.
+ conversion_node->set_assigned_device_name(n->assigned_device_name());
+
+ // Set the Mkl op label for this op.
+ conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
+
+ // Now that we have added edges from src->conversion_node, let's add edge from
+ // output of conversion_node to the element-wise node.
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
+ CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));
+
+ VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
+ << "conversion node on: " << n->type_string() << " successful.";
+
+ // Remove src->dst edge now.
+ (*g)->RemoveEdge(edges[0]);
+ (*g)->RemoveEdge(edges[1]);
+ (*g)->RemoveEdge(edges[2]);
+ (*g)->RemoveEdge(edges[3]);
+
+ return Status::OK();
+}
+
bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
bool result = false;
@@ -239,6 +355,49 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
DumpGraph("After MklToTfConversionPass", &**g);
+ //---------------------------------------------------------------------------
+ // Check all nodes and add an input-conversion-node if the node is an mkl
+ // element-wise node.
+ VLOG(1) << "Before running MklToTfConversionPass - InputConversion";
+
+ std::vector<Node*> candidate_nodes;
+ std::vector<Node*> order;
+ GetReversePostOrder(**g, &order); // This will give us topological sort.
+
+ for (Node* n : order) {
+ // If node is not an op or it does not have a datatype, then skip.
+ DataType datatype;
+ if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != Status::OK())) {
+ continue;
+ }
+ if (IsMklElementWiseOp(n->type_string(), datatype)) {
+ // If the input node is an input-conversion op, skip
+ Node* input_node = nullptr;
+ TF_CHECK_OK(n->input_node(0, &input_node));
+ DataType input_datatype;
+ if ((GetNodeAttr(n->def(), "T", &input_datatype) == Status::OK()) &&
+ (input_node->type_string().compare("_MklInputConversion") == 0)) {
+ continue;
+ }
+
+ VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
+ << n->name() << " for inserting input conversion node";
+ candidate_nodes.push_back(const_cast<Node*>(n));
+ }
+ }
+
+ // Process all candidate edges and insert conversion nodes on them.
+ for (Node* n : candidate_nodes) {
+ // Even if we insert conversion node on a single node, we
+ // need to return true.
+ if (InsertInputConversionNode(g, n) == Status::OK()) {
+ VLOG(1) << "MklToTfConversionPass: Inserted conversion "
+ << "on node " << n->name();
+ result = true;
+ }
+ }
+ DumpGraph("After MklToTfConversionPass - InputConversion", &**g);
+
// We need to return true even if we insert one conversion node
// anywhere in the graph.
return result;