diff options
author | 2018-03-15 22:04:33 -0700 | |
---|---|---|
committer | 2018-03-15 22:04:33 -0700 | |
commit | 6485bb7029c6d856c7ffa744168a8864ef6c986c (patch) | |
tree | 6edbf2b599515a6b85339a8f142182a54dd7ffb2 | |
parent | a6f8b220638484c0b6e54f3a7d445c155f578535 (diff) |
MKL DNN: fix the TF 1.6 speed issue by rewriting LRN to MKL DNN only when it takes the optimized path (#17605) (#17751)
* MKL DNN: fix the TF1.6 speed issue by fixing MKL DNN LRN
* fixed typos in the doc for LrnRewrite
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 02038c5d77..568fc87e65 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2492,10 +2492,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsDataType, AlwaysRewrite});
     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
-                      CopyAttrsLRN, AlwaysRewrite});
+                      CopyAttrsLRN, LrnRewrite});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
-                      CopyAttrsLRN, AlwaysRewrite});
+                      CopyAttrsLRN, LrnRewrite});
     rinfo_.push_back({csinfo_.max_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
@@ -2865,6 +2865,28 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }
 
+  // Rewrite rule for lrn and lrn_grad nodes: returns true (rewrite to the
+  // MKL DNN op) only when depth_radius is 2, the case for which MKL DNN
+  // takes its optimized path. For any other depth_radius MKL DNN takes a
+  // slow, unoptimized path, so we return false and keep the default Eigen op.
+  static bool LrnRewrite(const Node* n) {
+    CHECK_NOTNULL(n);
+
+    int depth_radius;
+    CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
+
+    // Only depth_radius == 2 is rewritten to MKL DNN; every other value
+    // falls through, is logged at VLOG(1), and keeps the Eigen node.
+    if (depth_radius == 2) {
+      return true;
+    }
+    VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which "
+            << "case is not optimized by Intel MKL, thus using Eigen op "
+            << "for LRN ";
+
+    return false;
+  }
+
   static bool AddNRewrite(const Node* n) {
     CHECK_NOTNULL(n);
 