Diffstat (limited to 'tensorflow/core/graph/mkl_optimizer_merge.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_optimizer_merge.cc | 124
1 file changed, 90 insertions, 34 deletions
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.cc b/tensorflow/core/graph/mkl_optimizer_merge.cc
index 98fc268d28..bc5915eda2 100644
--- a/tensorflow/core/graph/mkl_optimizer_merge.cc
+++ b/tensorflow/core/graph/mkl_optimizer_merge.cc
@@ -22,6 +22,8 @@ limitations under the License.
 #include <vector>
 #include <queue>
 #include <utility>
+#include <string>
+#include <memory>
 
 #include "tensorflow/core/graph/mkl_optimizer_merge.h"
 
@@ -33,6 +35,8 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
 
 namespace tensorflow {
 
@@ -58,8 +62,8 @@ static size_t kNodeMergeContextMaxDepth = 10;
 class NodeMergeRewritePass : public GraphOptimizationPass {
  public:
   NodeMergeRewritePass() {
-    csinfo_.conv2d = "Conv2D";
-    csinfo_.conv2dwithbias = "Conv2DWithBias";
+    csinfo_.conv2d = "MklConv2D";
+    csinfo_.conv2dwithbias = "MklConv2DWithBias";
     csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
     csinfo_.biasadd = "BiasAdd";
     csinfo_.matmul = "MatMul";
@@ -72,6 +76,9 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
     // maxhops in backward data-flow graph. Since input of forward nodes
     // (Conv2D) directly goes to backward nodes, we do not expect the
     // hop-distance would be more than few nodes.
+    // TODO(nhasabni) Temporarily disabling rewrite of BiasAddGrad.
+    // Will enable it once we support Conv2DWithBiasBackpropBias op.
+#if 0
     rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
                       {csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
     rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
@@ -80,6 +87,7 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
     // because we do not have a separate Op for MatMulwithBias.
     rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
                       {csinfo_.matmul, kNodeMergeContextMaxDepth}});
+#endif
   }
 
   // Standard interface to run optimization pass
@@ -182,10 +190,16 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
   // @return Matching rewriteinfo in case a match is found; null otherwise.
   const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
                                              const Node** fwdn) const;
+
+  // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
+  // and return it in '*out'.
+  // TODO(nhasabni) We should move this to mkl_util.h
+  void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out);
 };
 
-/// We register merge optimizer for phase 1 and MKLToTF insertion for phase 2.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
+// We register merge optimizer for phase 2 in pre-placement group.
+// Do not change the ordering of the Mkl passes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 2,
                       NodeMergeRewritePass);
 
 static void FillInputs(const Node* n,
@@ -219,8 +233,6 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
     }
   }
 
-  VLOG(1) << "FindNodeForMerge: " << a->type_string();
-
   for (const MergeInfo* mi : matching_mi) {
     const int N_in = a->num_inputs();
     if (mi->op >= N_in) {
@@ -240,8 +252,6 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
       continue;
     }
 
-    VLOG(1) << " FindNode: " << b->type_string();
-
     gtl::InlinedVector<Node*, 4> b_control_edges;
     gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
     FillInputs(b, &b_control_edges, &b_in);
@@ -258,6 +268,22 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
   return nullptr;
 }
 
+void NodeMergeRewritePass::GetDummyMklTensorNode(
+    std::unique_ptr<Graph>* g, Node** out) {
+  const DataType dt = DataTypeToEnum<uint8>::v();
+  TensorProto proto;
+  proto.set_dtype(dt);
+  uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+  proto.set_tensor_content(const_cast<const void*>(
+      static_cast<void*>(&zero)), 8);
+  TensorShape dummy_shape({8});
+  dummy_shape.AsProto(proto.mutable_tensor_shape());
+  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+                  .Attr("value", proto)
+                  .Attr("dtype", dt)
+                  .Finalize(&**g, out));
+}
+
 Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
                                        Node* succ, Node* pred) {
   CHECK_NOTNULL(succ);
@@ -271,7 +297,6 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
   std::vector<int32> strides;
   string data_format_pred, data_format_succ;
   bool use_cudnn_on_gnu;
-  int groups = 1;
   TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
   TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
   TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
@@ -280,25 +305,28 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
   TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
   TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
-  // Groups attribute may not be there on the input node. So we do not
-  // check for error in GetNodeAttr call.
-  GetNodeAttr(pred->def(), "groups", &groups);
 
   // We check to ensure that data formats of both succ and pred are same.
   // We expect them to be same, so we can enforce this as assert.
   // But assert can be too strict, so we enforce this as a check.
   // If the check fails, then we do not merge two nodes.
+  // We also do same check for devices.
   if (data_format_pred != data_format_succ ||
-      T_pred != T_succ) {
+      T_pred != T_succ ||
+      pred->assigned_device_name() != succ->assigned_device_name() ||
+      pred->def().device() != succ->def().device()) {
     return Status(error::Code::INVALID_ARGUMENT,
-                  "data_format or T attribute of Conv2D and BiasAdd"
-                  "do not match. Will skip node merge optimization");
+                  "data_format or T attribute or devices of Conv2D and "
+                  "BiasAdd do not match. Will skip node merge optimization");
   }
 
   // 2. Get inputs from both the nodes.
   // Find the 2 inputs from the conv and the bias from the add Bias.
   Node* oper1 = nullptr;
+  Node* oper1_mkl = nullptr;  // Mkl tensor corresponding to oper1
   Node* oper2 = nullptr;
+  Node* oper2_mkl = nullptr;  // Mkl tensor corresponding to oper2
   Node* oper3 = nullptr;
+  Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
 
   const int succ_num = succ->num_inputs();
   gtl::InlinedVector<Node*, 4> succ_control_edges;
@@ -326,24 +354,35 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
     }
   }
 
-  // Get operand 0, 1 of conv2D
-  oper1 = pred_in[0].first;
-  oper2 = pred_in[1].first;
+  // Get operand 0, 1 of conv2D and their Mkl tensors.
+  CHECK_EQ(pred->in_edges().size(), 4);  // MklConv2D must have 4 inputs.
+  oper1 = pred_in[0].first;
+  oper1_mkl = pred_in[1].first;
+  oper2 = pred_in[2].first;
+  oper2_mkl = pred_in[3].first;
 
   // Get operand 1 of add_bias
-  oper3 = succ_in[1].first;
+  // BiasAdd must have 2 inputs: Conv, bias
+  CHECK_EQ(succ->in_edges().size(), 2);
+  oper3 = succ_in[1].first;
+  GetDummyMklTensorNode(g, &oper3_mkl);  // Get dummy Mkl tensor node
+  // as BiasAdd does not have Mkl tensor as input.
+  CHECK_NOTNULL(oper3_mkl);
 
   Node* ret;
   // We will use the node name of BiasAdd as the name of new node
   TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
                   .Input(oper1)
+                  .Input(oper1_mkl)
                   .Input(oper2)
+                  .Input(oper2_mkl)
                   .Input(oper3)
+                  .Input(oper3_mkl)
                   .Attr("T", T_pred)
                   .Attr("strides", strides)
                   .Attr("padding", padding)
                   .Attr("data_format", data_format_pred)
                   .Attr("use_cudnn_on_gpu", use_cudnn_on_gnu)
-                  .Attr("groups", groups)
+                  .Device(succ->def().device())
                   .Finalize(&**g, &ret));
   CHECK_NOTNULL(ret);
@@ -352,6 +391,15 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
     (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
   }
 
+  // Copy device assigned to old node to new node.
+  // It's ok to use pred or succ as we have enforced a check that
+  // both have same device assigned.
+  ret->set_assigned_device_name(pred->assigned_device_name());
+
+  VLOG(1) << "NodeMergeRewritePass: Merged old node:" << pred->DebugString()
+          << ", and node: " << succ->DebugString() << ", into node:"
+          << ret->DebugString();
+
   (*g)->RemoveNode(succ);
   (*g)->RemoveNode(pred);
 
@@ -369,13 +417,14 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) {
   const Node* fwdn = nullptr;
   const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
   if (ri == nullptr || fwdn == nullptr) {
-    VLOG(1) << "Rewriteinfo not found for: " << n->type_string();
+    VLOG(2) << "NodeMergeRewritePass: Rewriteinfo not found for: "
+            << n->type_string();
     return Status(error::Code::INVALID_ARGUMENT,
                   "Rewrite info not found for the node."
"Will skip node rewrite optimization"); } - VLOG(1) << "Rewrite called for: " << n->type_string(); + VLOG(1) << "NodeMergeRewritePass: Rewrite called for: " << n->type_string(); if (n->type_string() == csinfo_.biasaddgrad && ri->node == csinfo_.biasaddgrad && @@ -407,6 +456,7 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) { .Attr("T", T) .Attr("data_format", data_format) .Attr("strides", strides) + .Device(n->def().device()) .Finalize(&**g, &ret)); } else { CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad); @@ -414,6 +464,7 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) { .Input(op) .Attr("T", T) .Attr("data_format", data_format) + .Device(n->def().device()) .Finalize(&**g, &ret)); } @@ -424,7 +475,11 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) { (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input()); } - VLOG(1) << "Rewrite node: " << n->type_string() << " successful"; + // Copy device assigned to old node to new node. + ret->set_assigned_device_name(n->assigned_device_name()); + + VLOG(1) << "MKLOptimizerMergePass: Rewrote old node:" << n->DebugString() + << ", into node:" << ret->DebugString(); (*g)->RemoveNode(n); return Status::OK(); @@ -450,7 +505,8 @@ NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n, } } - VLOG(1) << "Searching graph for: " << n->type_string() << " in backwards."; + VLOG(1) << "NodeMergeRewritePass: Searching graph for: " + << n->type_string() << " in backwards."; // Now we will check for forward op name for rewrite info in data // flow graph. Get the max hops we should search for the fwd node @@ -473,7 +529,8 @@ NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n, curr_depth = curr_pair.second; CHECK_NOTNULL(curr_node); - VLOG(1) << "Visiting node: " << curr_node->type_string() + VLOG(1) << "NodeMergeRewritePass: Visiting node: " + << curr_node->type_string() << " at depth: " << curr_depth << " for node: " << n->type_string(); @@ -528,17 +585,16 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) { std::vector<std::pair<Node*, Node*>> nodes_to_be_merged; std::vector<Node*> nodes_to_be_rewritten; - VLOG(1) << "Running NodeMerge Optimization"; - for (Node* n : order) { if (!n->IsOp()) continue; Node* n1 = nullptr; if ((n1 = FindNodeForMerge(n)) != nullptr) { - VLOG(1) << "Scheduled nodes " << n->name() << " and " - << n1->name() << " for merging"; + VLOG(1) << "NodeMergeRewritePass: Scheduled nodes " + << n->name() << " and " << n1->name() << " for merging"; nodes_to_be_merged.push_back(std::make_pair(n, n1)); } else if (IsApplicableRewriteNode(n)) { - VLOG(1) << "Scheduled node " << n->name() << " for rewrite"; + VLOG(1) << "NodeMergeRewritePass: Scheduled node " << n->name() + << " for rewrite"; nodes_to_be_rewritten.push_back(n); } } @@ -549,7 +605,8 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) { string n1_name = i.first->name(); string n2_name = i.second->name(); if (MergeNode(g, i.first, i.second) == Status::OK()) { - VLOG(1) << "Merged nodes " << n1_name << " and " << n2_name; + VLOG(1) << "NodeMergeRewritePass: Merged nodes " << n1_name + << " and " << n2_name; result = true; } } @@ -559,7 +616,8 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) { for (Node* i : nodes_to_be_rewritten) { string name = i->name(); if (RewriteNode(g, i) == Status::OK()) { - VLOG(1) << "Rewrite node: " << name << " successful."; + VLOG(1) << "NodeMergeRewritePass: Rewrite node: " + << name << " successful."; result 
     }
   }
 
@@ -574,8 +632,6 @@ bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
 }
 
 Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
-  // Currently checking only for two cases - Conv2D+Bias and Matmul+Bias.
-  // It is possible to extend it to other operators in future.
  if (options.graph == nullptr) {
    return Status::OK();
  }
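Note: the merge above wires the six inputs of the new MklConv2DWithBias node in interleaved (data, Mkl-metadata) order: oper1, oper1_mkl, oper2, oper2_mkl, oper3, oper3_mkl. Because BiasAdd carries no Mkl metadata for the bias operand, GetDummyMklTensorNode supplies a placeholder. Below is a minimal standalone sketch of that helper, assuming the NodeBuilder/TensorProto APIs at this revision; the free function MakeDummyMklTensorNode is a hypothetical name (the commit's version is the member function GetDummyMklTensorNode), and the string overload of set_tensor_content stands in for the const_cast/static_cast pair used in the patch:

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical free-function variant of GetDummyMklTensorNode above:
// builds a Const node holding eight zero bytes of type uint8, which the
// merged MklConv2DWithBias can consume as the Mkl metadata input for an
// operand (such as the bias) that has no real Mkl layout.
Status MakeDummyMklTensorNode(Graph* g, Node** out) {
  const DataType dt = DataTypeToEnum<uint8>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  // Eight zero bytes as the tensor content, shape {8}.
  proto.set_tensor_content(string(8, '\0'));
  TensorShape({8}).AsProto(proto.mutable_tensor_shape());
  return NodeBuilder(g->NewName("DMT"), "Const")
             .Attr("value", proto)
             .Attr("dtype", dt)
             .Finalize(g, out);
}

}  // namespace tensorflow

The commit's version keeps the std::unique_ptr<Graph>* parameter so it can be called directly from MergeNode; the sketch takes a raw Graph* only to stay self-contained.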