aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_layout_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc167
1 files changed, 127 insertions, 40 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index cf5d6e8baa..90377e54c7 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -256,6 +256,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
// NOTE: names are alphabetically sorted.
+ csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
@@ -279,17 +280,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"_MklConv2DWithBiasBackpropBias";
- csinfo_.relu = "Relu";
- csinfo_.relu_grad = "ReluGrad";
- csinfo_.reshape = "Reshape";
- csinfo_.split = "Split";
+ csinfo_.relu = "Relu";
+ csinfo_.relu_grad = "ReluGrad";
+ csinfo_.reshape = "Reshape";
+ csinfo_.split = "Split";
+ // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
+ // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
+ // MklInputConversion op is added before it.
+ csinfo_.add = "Add";
+ csinfo_.maximum = "Maximum";
+ csinfo_.mul = "Mul";
+ csinfo_.squared_difference = "SquaredDifference";
+ csinfo_.sub = "Sub";
+ // End - element-wise ops. See note above.
// NOTE: names are alphabetically sorted.
+ rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN,
+ AddNRewrite, nullptr});
+ rinfo_.push_back({csinfo_.add,
+ mkl_op_registry::GetMklOpName(csinfo_.add),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool,
- GetMklOpName(csinfo_.avg_pool),
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsPooling, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool_grad,
- GetMklOpName(csinfo_.avg_pool_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
// BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending
// on if context contains Conv2D.
@@ -303,50 +318,62 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsBiasAddGrad, ContextMatchRewrite,
&biasaddgrad_matmul_context_});
rinfo_.push_back({csinfo_.concat,
- GetMklOpName(csinfo_.concat),
+ mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.concatv2,
- GetMklOpName(csinfo_.concatv2),
+ mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsConcatV2, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d,
- GetMklOpName(csinfo_.conv2d),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
- GetMklOpName(csinfo_.conv2d_grad_filter),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_input,
- GetMklOpName(csinfo_.conv2d_grad_input),
+ mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm,
- GetMklOpName(csinfo_.fused_batch_norm),
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
- GetMklOpName(csinfo_.fused_batch_norm_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.identity,
- GetMklOpName(csinfo_.identity),
+ mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsIdentity, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn,
- GetMklOpName(csinfo_.lrn),
+ mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.lrn_grad,
- GetMklOpName(csinfo_.lrn_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool,
- GetMklOpName(csinfo_.max_pool),
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool_grad,
- GetMklOpName(csinfo_.max_pool_grad),
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.maximum,
+ mkl_op_registry::GetMklOpName(csinfo_.maximum),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.mul,
+ mkl_op_registry::GetMklOpName(csinfo_.mul),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu,
- GetMklOpName(csinfo_.relu),
- CopyAttrsRelu, AlwaysRewrite, nullptr});
+ mkl_op_registry::GetMklOpName(csinfo_.relu),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.relu_grad,
- GetMklOpName(csinfo_.relu_grad),
- CopyAttrsRelu, AlwaysRewrite, nullptr});
+ mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.reshape,
- GetMklOpName(csinfo_.reshape),
+ mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.squared_difference,
+ mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.sub,
+ mkl_op_registry::GetMklOpName(csinfo_.sub),
+ CopyAttrsDataType, AlwaysRewrite, nullptr});
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -429,6 +456,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
typedef struct {
+ string addn;
+ string add;
string avg_pool;
string avg_pool_grad;
string bias_add;
@@ -446,15 +475,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
string mkl_conv2d_grad_filter;
string mkl_conv2d_with_bias;
string mkl_conv2d_with_bias_backprop_bias;
+ string mul;
string relu;
string relu_grad;
string reshape;
string split;
+ string squared_difference;
+ string sub;
} ConstStringsInfo;
private:
@@ -502,15 +535,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return N;
}
- // Get the name of Mkl op from original TensorFlow op
- // We prefix 'Mkl' to the original op to get Mkl op.
- // TODO(nhasabni) We should move this to mkl_util.h.
- inline string GetMklOpName(const string& name) const {
- // Prefix that we add to Tensorflow op name to construct Mkl op name.
- const char* const kMklOpPrefix = "_Mkl";
- return string(kMklOpPrefix) + name;
- }
-
// Can op represented by node 'n' run on DEVICE_CPU?
// Op can run on CPU with MKL if the runtime assigned device or the
// user requested device contains device CPU, or both are empty.
@@ -604,6 +628,19 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
+ static bool AddNRewrite(const Node* n, const ContextInfo* c) {
+ CHECK_NOTNULL(n);
+
+ int num;
+ CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
+
+ // Condition that specifies non-batch-wise and non-depth-wise pooling.
+ if (num == 2) {
+ return true;
+ }
+
+ return false;
+ }
// Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node
// specified in contextinfo 'ci'. Function updates fwd_node to point
// to Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
@@ -907,15 +944,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
// NOTE: names are alphabetically sorted.
+ static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
@@ -1334,7 +1372,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op &&
- mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) {
+ mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is
// an edge, then we just add an attribute on this node for setting
@@ -1360,7 +1398,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
- mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
+ mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
@@ -1376,7 +1414,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
if (e->src_output() == ws.fwd_slot &&
// We would have rewritten the forward op, so we need to use
// GetMklOpName call to get its Mkl name.
- e->src()->type_string() == GetMklOpName(ws.fwd_op) &&
+ e->src()->type_string() == mkl_op_registry::GetMklOpName(ws.fwd_op) &&
e->dst_input() == ws.bwd_slot) {
nb->Attr("workspace_enabled", true);
CHECK_NOTNULL(ws_tensors);
@@ -1455,6 +1493,20 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
+void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ int N;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("N", N);
+}
+
void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -1527,8 +1579,8 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
nb->Attr("data_format", data_format);
}
-void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
// Get all attributes from old node.
@@ -1894,7 +1946,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
}
// Get all inputs.
- const int num_inputs = orig_node->in_edges().size();
+ int num_inputs = orig_node->in_edges().size();
+
+ // Drop count for control edges from inputs
+ for (const Edge* e : orig_node->in_edges()) {
+ if (e->IsControlEdge()) {
+ num_inputs--;
+ }
+ }
+
gtl::InlinedVector<Node*, 4> control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
FillInputs(orig_node, &control_edges, &inputs);
@@ -2008,7 +2068,34 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
// BiasAddGrad is not an Mkl layer, so we make an exception for it.
if (n->type_string() != csinfo_.bias_add_grad) {
- if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
+ if (!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+ return nullptr;
+ }
+ }
+
+ // For elementwise node, we reuse the Eigen implementation and pass the MKL
+ // metadata tensor through so we can avoid conversions. However, if all
+ // incoming edges are in TF format, we don't need all this overhead, so
+ // replace the elementwise node only if at least one of its parents is a MKL
+ // node.
+ //
+ // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
+ // eigen code to reduce cross-library dependency.
+ if (mkl_op_registry::IsMklElementWiseOp(
+ mkl_op_registry::GetMklOpName(n->type_string()), T)) {
+ bool incoming_mkl_edge = false;
+ for (auto parent : n->in_edges()) {
+ if (mkl_op_registry::IsMklOp(
+ mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) {
+ incoming_mkl_edge = true;
+ break;
+ } else {
+ VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string();
+ }
+ }
+ if (incoming_mkl_edge == false) {
+ VLOG(1) << "Skipping replacement of elementwise node which has no MKL "
+ "parents.";
return nullptr;
}
}