aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_layout_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc81
1 files changed, 40 insertions, 41 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index b9667998d6..c22e0a3872 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <memory>
#include <queue>
#include <set>
-#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@@ -2495,13 +2494,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsLRN, LrnRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
- CopyAttrsLRN, LrnRewrite});
+ CopyAttrsLRN, LrnGradRewrite});
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
- CopyAttrsPooling, AlwaysRewrite});
+ CopyAttrsPooling, MaxpoolGradRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
@@ -2887,6 +2886,41 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
+ static bool LrnGradRewrite(const Node* n) {
+ CHECK_NOTNULL(n);
+ bool do_rewrite = false;
+
+ for (const Edge* e : n->in_edges()) {
+ // Rewrite only if there is corresponding LRN, i.e workspace is available
+ if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
+ e->src_output() == 0) {
+ do_rewrite = true;
+ break;
+ }
+ }
+ return do_rewrite;
+ }
+
+ static bool MaxpoolGradRewrite(const Node* n) {
+ CHECK_NOTNULL(n);
+ bool do_rewrite = false;
+ for (const Edge* e : n->in_edges()) {
+ // Rewrite only if there is corresponding Maxpool, i.e workspace is
+ // available
+ if (e->dst()->type_string() == csinfo_.max_pool_grad &&
+ e->dst_input() == 1 &&
+ e->src()->type_string() ==
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
+ e->src_output() == 0) {
+ do_rewrite = true;
+ break;
+ }
+ }
+ return do_rewrite;
+ }
+
static bool AddNRewrite(const Node* n) {
CHECK_NOTNULL(n);
@@ -3421,44 +3455,9 @@ Status MklLayoutRewritePass::SetUpInputs(
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
- // We use a tensor of shape {1} and value 0 to represent
- // dummy float tensor. We need this as a dummy workspace tensor.
- // Workspace tensor has type uint8.
- const DataType dt = DataTypeToEnum<uint8>::v();
- TensorProto proto;
- proto.set_dtype(dt);
- float zero[1] = {0};
- proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4));
- TensorShape dummy_shape({1});
- dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // same the device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
-
- // If number of inputs to the original node is > 0, then we add
- // control dependency between 1st input (index 0) of the original node and
- // the dummy Mkl node. This is needed because control-flow ops such as Enter,
- // Merge, etc, require frame_name of the dummy Mkl node to be same as the
- // rewritten node. Adding control edge between 1st input of the original node
- // and the dummy Mkl node ensures that the dummy node is in the same frame
- // as the original node. Choosing 1st input is not necessary - any input of
- // the original node is fine because all the inputs of a node are always in
- // the same frame.
- if (orig_node->num_inputs() > 0) {
- Node* orig_input0 = nullptr;
- TF_CHECK_OK(
- orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
- // Allow duplicate while adding control edge as it would fail (return
- // NULL) if we try to add duplicate edge.
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
- }
-
- (*out)->set_assigned_device_name(orig_node->assigned_device_name());
+ // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent
+ // workspace tensor.
+ GetDummyMklTensorNode(g, out, orig_node);
}
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(