diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 11:04:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 11:04:15 -0700 |
commit | 1f7fde6ccaf7ff1fee530c816e6df919c561a2ce (patch) | |
tree | ae1f7fba49e8e45fb2dbde09f4f543398950c514 /tensorflow/core/graph/mkl_layout_pass.cc | |
parent | 50ba36f1662dc61cb1b60353a2a09aa3ea72bb59 (diff) | |
parent | f565cdeef92861eb70b91c36460d0130254f2c91 (diff) |
Merge pull request #21007 from Intel-tensorflow:agramesh/parallel_for_fix
PiperOrigin-RevId: 206611194
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 80 |
1 files changed, 40 insertions, 40 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 3e769b5303..c22e0a3872 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2494,13 +2494,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsLRN, LrnRewrite}); rinfo_.push_back({csinfo_.lrn_grad, mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), - CopyAttrsLRN, LrnRewrite}); + CopyAttrsLRN, LrnGradRewrite}); rinfo_.push_back({csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool), CopyAttrsPooling, NonDepthBatchWisePoolRewrite}); rinfo_.push_back({csinfo_.max_pool_grad, mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), - CopyAttrsPooling, AlwaysRewrite}); + CopyAttrsPooling, MaxpoolGradRewrite}); rinfo_.push_back({csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), @@ -2886,6 +2886,41 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } + static bool LrnGradRewrite(const Node* n) { + CHECK_NOTNULL(n); + bool do_rewrite = false; + + for (const Edge* e : n->in_edges()) { + // Rewrite only if there is corresponding LRN, i.e workspace is available + if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(csinfo_.lrn) && + e->src_output() == 0) { + do_rewrite = true; + break; + } + } + return do_rewrite; + } + + static bool MaxpoolGradRewrite(const Node* n) { + CHECK_NOTNULL(n); + bool do_rewrite = false; + for (const Edge* e : n->in_edges()) { + // Rewrite only if there is corresponding Maxpool, i.e workspace is + // available + if (e->dst()->type_string() == csinfo_.max_pool_grad && + e->dst_input() == 1 && + e->src()->type_string() == + mkl_op_registry::GetMklOpName(csinfo_.max_pool) && + e->src_output() == 0) { + do_rewrite = true; + break; + } + } + return do_rewrite; + } + static bool AddNRewrite(const Node* n) { CHECK_NOTNULL(n); @@ -3420,44 +3455,9 @@ Status MklLayoutRewritePass::SetUpInputs( // TODO(nhasabni) We should move this to mkl_util.h. void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { - // We use a tensor of shape {1} and value 0 to represent - // dummy float tensor. We need this as a dummy workspace tensor. - // Workspace tensor has type uint8. - const DataType dt = DataTypeToEnum<uint8>::v(); - TensorProto proto; - proto.set_dtype(dt); - float zero[1] = {0}; - proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4)); - TensorShape dummy_shape({1}); - dummy_shape.AsProto(proto.mutable_tensor_shape()); - TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); - - // If number of inputs to the original node is > 0, then we add - // control dependency between 1st input (index 0) of the original node and - // the dummy Mkl node. This is needed because control-flow ops such as Enter, - // Merge, etc, require frame_name of the dummy Mkl node to be same as the - // rewritten node. Adding control edge between 1st input of the original node - // and the dummy Mkl node ensures that the dummy node is in the same frame - // as the original node. Choosing 1st input is not necessary - any input of - // the original node is fine because all the inputs of a node are always in - // the same frame. - if (orig_node->num_inputs() > 0) { - Node* orig_input0 = nullptr; - TF_CHECK_OK( - orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); - // Allow duplicate while adding control edge as it would fail (return - // NULL) if we try to add duplicate edge. - CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); - } - - (*out)->set_assigned_device_name(orig_node->assigned_device_name()); + // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent + // workspace tensor. + GetDummyMklTensorNode(g, out, orig_node); } void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( |