diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 167 |
1 file changed, 127 insertions, 40 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index cf5d6e8baa..90377e54c7 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -256,6 +256,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { public: MklLayoutRewritePass() { // NOTE: names are alphabetically sorted. + csinfo_.addn = "AddN"; csinfo_.avg_pool = "AvgPool"; csinfo_.avg_pool_grad = "AvgPoolGrad"; csinfo_.bias_add = "BiasAdd"; @@ -279,17 +280,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = "_MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.reshape = "Reshape"; - csinfo_.split = "Split"; + csinfo_.relu = "Relu"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.reshape = "Reshape"; + csinfo_.split = "Split"; + // Element-wise ops. Ensure you also add any new ops to IsOpElementWise + // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the + // MklInputConversion op is added before it. + csinfo_.add = "Add"; + csinfo_.maximum = "Maximum"; + csinfo_.mul = "Mul"; + csinfo_.squared_difference = "SquaredDifference"; + csinfo_.sub = "Sub"; + // End - element-wise ops. See note above. // NOTE: names are alphabetically sorted. 
+ rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, + AddNRewrite, nullptr}); + rinfo_.push_back({csinfo_.add, + mkl_op_registry::GetMklOpName(csinfo_.add), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool, - GetMklOpName(csinfo_.avg_pool), + mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsPooling, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool_grad, - GetMklOpName(csinfo_.avg_pool_grad), + mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), CopyAttrsPooling, AlwaysRewrite, nullptr}); // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending // on if context contains Conv2D. @@ -303,50 +318,62 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsBiasAddGrad, ContextMatchRewrite, &biasaddgrad_matmul_context_}); rinfo_.push_back({csinfo_.concat, - GetMklOpName(csinfo_.concat), + mkl_op_registry::GetMklOpName(csinfo_.concat), CopyAttrsConcat, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.concatv2, - GetMklOpName(csinfo_.concatv2), + mkl_op_registry::GetMklOpName(csinfo_.concatv2), CopyAttrsConcatV2, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d, - GetMklOpName(csinfo_.conv2d), + mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_filter, - GetMklOpName(csinfo_.conv2d_grad_filter), + mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_input, - GetMklOpName(csinfo_.conv2d_grad_input), + mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm, - GetMklOpName(csinfo_.fused_batch_norm), + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm_grad, - 
GetMklOpName(csinfo_.fused_batch_norm_grad), + mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.identity, - GetMklOpName(csinfo_.identity), + mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsIdentity, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn, - GetMklOpName(csinfo_.lrn), + mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.lrn_grad, - GetMklOpName(csinfo_.lrn_grad), + mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), CopyAttrsLRN, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool, - GetMklOpName(csinfo_.max_pool), + mkl_op_registry::GetMklOpName(csinfo_.max_pool), CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool_grad, - GetMklOpName(csinfo_.max_pool_grad), + mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), CopyAttrsPooling, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.maximum, + mkl_op_registry::GetMklOpName(csinfo_.maximum), + CopyAttrsDataType, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.mul, + mkl_op_registry::GetMklOpName(csinfo_.mul), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu, - GetMklOpName(csinfo_.relu), - CopyAttrsRelu, AlwaysRewrite, nullptr}); + mkl_op_registry::GetMklOpName(csinfo_.relu), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.relu_grad, - GetMklOpName(csinfo_.relu_grad), - CopyAttrsRelu, AlwaysRewrite, nullptr}); + mkl_op_registry::GetMklOpName(csinfo_.relu_grad), + CopyAttrsDataType, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.reshape, - GetMklOpName(csinfo_.reshape), + mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.squared_difference, + mkl_op_registry::GetMklOpName(csinfo_.squared_difference), + CopyAttrsDataType, AlwaysRewrite, nullptr}); + 
+ rinfo_.push_back({csinfo_.sub, + mkl_op_registry::GetMklOpName(csinfo_.sub), + CopyAttrsDataType, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -429,6 +456,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Structure to store all constant strings /// NOTE: names are alphabetically sorted. typedef struct { + string addn; + string add; string avg_pool; string avg_pool_grad; string bias_add; @@ -446,15 +475,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string matmul; string max_pool; string max_pool_grad; + string maximum; string mkl_conv2d; string mkl_conv2d_grad_input; string mkl_conv2d_grad_filter; string mkl_conv2d_with_bias; string mkl_conv2d_with_bias_backprop_bias; + string mul; string relu; string relu_grad; string reshape; string split; + string squared_difference; + string sub; } ConstStringsInfo; private: @@ -502,15 +535,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return N; } - // Get the name of Mkl op from original TensorFlow op - // We prefix 'Mkl' to the original op to get Mkl op. - // TODO(nhasabni) We should move this to mkl_util.h. - inline string GetMklOpName(const string& name) const { - // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "_Mkl"; - return string(kMklOpPrefix) + name; - } - // Can op represented by node 'n' run on DEVICE_CPU? // Op can run on CPU with MKL if the runtime assigned device or the // user requested device contains device CPU, or both are empty. @@ -604,6 +628,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } + static bool AddNRewrite(const Node* n, const ContextInfo* c) { + CHECK_NOTNULL(n); + + int num; + CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); + + // Rewrite AddN only when it has exactly two inputs (attr N == 2). 
+ if (num == 2) { + return true; + } + + return false; + } // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node // specified in contextinfo 'ci'. Function updates fwd_node to point // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias. @@ -907,15 +944,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // We need operator-specific function to copy attributes because the framework // does not provide any generic function for it. // NOTE: names are alphabetically sorted. + static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); @@ -1334,7 +1372,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); for (auto ws : wsinfo_) { if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) { + mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a fwd op, then we need to check if there is an // edge from this node's fwd_slot to bwdop's bwd_slot. 
If there is // an edge, then we just add an attribute on this node for setting @@ -1360,7 +1398,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( nb->Attr("workspace_enabled", false); } } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), + mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this @@ -1376,7 +1414,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == GetMklOpName(ws.fwd_op) && + e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) && e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); CHECK_NOTNULL(ws_tensors); @@ -1455,6 +1493,20 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); } +void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + int N; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("N", N); +} + void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb) { DataType T; @@ -1527,8 +1579,8 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, nb->Attr("data_format", data_format); } -void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node, - NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, + NodeBuilder* nb) { DataType T; // Get all attributes from old node. 
@@ -1894,7 +1946,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, } // Get all inputs. - const int num_inputs = orig_node->in_edges().size(); + int num_inputs = orig_node->in_edges().size(); + + // Drop count for control edges from inputs + for (const Edge* e : orig_node->in_edges()) { + if (e->IsControlEdge()) { + num_inputs--; + } + } + gtl::InlinedVector<Node*, 4> control_edges; gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); FillInputs(orig_node, &control_edges, &inputs); @@ -2008,7 +2068,34 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // BiasAddGrad is not an Mkl layer, so we make an exception for it. if (n->type_string() != csinfo_.bias_add_grad) { - if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { + if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { + return nullptr; + } + } + + // For elementwise node, we reuse the Eigen implementation and pass the MKL + // metadata tensor through so we can avoid conversions. However, if all + // incoming edges are in TF format, we don't need all this overhead, so + // replace the elementwise node only if at least one of its parents is a MKL + // node. + // + // TODO(vrane): Add implementation for element-wise ops that doesn't reuse + // eigen code to reduce cross-library dependency. + if (mkl_op_registry::IsMklElementWiseOp( + mkl_op_registry::GetMklOpName(n->type_string()), T)) { + bool incoming_mkl_edge = false; + for (auto parent : n->in_edges()) { + if (mkl_op_registry::IsMklOp( + mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) { + incoming_mkl_edge = true; + break; + } else { + VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string(); + } + } + if (incoming_mkl_edge == false) { + VLOG(1) << "Skipping replacement of elementwise node which has no MKL " + "parents."; return nullptr; } } |