Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc  148
1 file changed, 141 insertions, 7 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 72a13d4da7..b9667998d6 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2691,14 +2691,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -2865,9 +2865,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
- // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
- // path. The unoptimized path is slow. Thus we dont rewrite the node
- // and use default Eigen. But for depth_radius=2, MKL DNN optimized
+ // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized
+ // path. The unoptimized path is slow. Thus we dont rewrite the node
+ // and use default Eigen. But for depth_radius=2, MKL DNN optimized
// path is taken, i.e., eigen node is rewritten by MKl DNN node.
static bool LrnRewrite(const Node* n) {
CHECK_NOTNULL(n);
@@ -2876,13 +2876,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);
// if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
- // and use eigen node instead
+ // and use eigen node instead
if (depth_radius == 2) {
return true;
}
VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
<< "case is not optimized by Intel MKL, thus using Eigen op"
- << "for LRN " ;
+ << "for LRN ";
return false;
}
@@ -3015,6 +3015,35 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);
+ // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
+ // pointed to by 'e_metadata' that corresponds to the data edge 'e_data' in
+ // graph 'g'. Returns true if the fixup was done; otherwise, it returns false.
+ bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
+ const Edge* e_data, const Edge* e_metadata);
+
+ // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
+ // connected? If not, then fix them. This is needed because a graph may have
+ // some input Mkl metadata edges set up incorrectly after the node merge and
+ // rewrite passes. This can happen because the GetReversePostOrder function
+ // may not provide a topologically sorted order if the graph contains cycles.
+ // The function returns true if at least one Mkl metadata edge for node 'n'
+ // was fixed; otherwise, it returns false.
+ //
+ // Example:
+ //
+ // X = MklConv2D(_, _, _)
+ // Y = MklConv2DWithBias(_, _, _, _, _, _)
+ // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
+ //
+ // For a graph such as the one shown above, note that the 3rd argument of
+ // MklAdd contains DummyMklTensor. It should instead be getting the Mkl
+ // metadata from the MklConv2D op (specifically, X:2). This incorrect
+ // plumbing is possible (although rare) if the Mkl NodeMerge + NodeRewrite
+ // passes visit Z before X (which can happen if X, Y, and Z are part of a
+ // loop). This function fixes the Mkl metadata edges only; it does not
+ // rewrite nodes, nor does it modify the Mkl data edges (the 1st and 2nd
+ // arguments of MklAdd).
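+ //
+ // Concretely, for the example above the fixup rewires Z's Mkl metadata
+ // input for X from the dummy constant to X:2 (the metadata output slot
+ // paired with X's data output), leaving all other edges untouched.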
+ bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
+
// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
// does not provide any generic function for it.
@@ -4242,6 +4271,92 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
}
///////////////////////////////////////////////////////////////////////////////
+// Post-rewrite Mkl metadata fixup pass
+///////////////////////////////////////////////////////////////////////////////
+bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
+ const Edge* e_data, const Edge* e_metadata) {
+ if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
+ return false;
+ }
+
+ Node* n_data = e_data->src();
+ int n_data_op_slot = e_data->src_output();
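+ // Output slot on 'n_data' that carries the Mkl metadata tensor paired with
+ // the data tensor produced at 'n_data_op_slot'.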
+ int n_metadata_op_slot = GetTensorMetaDataIndex(n_data_op_slot,
+ n_data->num_outputs());
+
+ // If the source of the metadata edge is a constant node (producing a dummy
+ // Mkl metadata tensor), then we need to fix it.
+ if (IsConstant(e_metadata->src())) {
+ Node* e_metadata_dst = e_metadata->dst();
+ int e_metadata_in_slot = e_metadata->dst_input();
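+ // Reconnect the consumer's metadata input slot to the metadata output slot
+ // of the node that produces the corresponding data tensor, then remove the
+ // edge that came from the dummy constant.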
+ CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot,
+ e_metadata_dst, e_metadata_in_slot));
+
+ (*g)->RemoveEdge(e_metadata);
+ return true;
+ }
+
+ return false;
+}
+
+bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
+ Node* n) {
+ bool result = false;
+
+ // If the graph node is not an Mkl node, then return.
+ DataType T = DT_INVALID;
+ if (!GetNodeAttr(n->def(), "T", &T).ok() ||
+ !mkl_op_registry::IsMklOp(n->type_string(), T)) {
+ return result;
+ }
+
+ // If it is an Mkl node, then check whether the input edges to this node
+ // that carry Mkl metadata are linked up correctly with their source nodes.
+
+ // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
+ // data tensors + n for Mkl metadata tensors). We need to check for correct
+ // connection of n metadata tensors only.
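+ // (For instance, MklAdd in the example above has two data inputs and two
+ // metadata inputs, so num_data_inputs would be 2 for it.)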
+ int num_data_inputs = n->num_inputs() / 2;
+ for (int idx = 0; idx < num_data_inputs; idx++) {
+ // Get the edge connecting input slot with index (idx).
+ const Edge* e = nullptr;
+ TF_CHECK_OK(n->input_edge(idx, &e));
+
+ // If e is control edge, then skip.
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ // Check whether the source node for edge 'e' is an Mkl node. If it is not
+ // an Mkl node, then we don't need to do anything.
+ Node* e_src = e->src();
+ if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
+ mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
+ // Source node for edge 'e' is Mkl node.
+ // The destination node and destination input slot of 'e' are node 'n' and
+ // 'idx', respectively.
+ CHECK_EQ(e->dst(), n);
+ CHECK_EQ(e->dst_input(), idx);
+
+ // Get the edge that carries the Mkl metadata corresponding to the Mkl data
+ // edge 'e'. For that, first get the input slot of 'n' where the metadata
+ // edge feeds its value.
+ int e_meta_in_slot = GetTensorMetaDataIndex(e->dst_input(),
+ n->num_inputs());
+ const Edge* e_meta = nullptr;
+ TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
+
+ // Check whether this metadata edge needs to be fixed, i.e., whether it is
+ // currently fed by a dummy constant.
+ if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
+ result = true;
+ }
+ }
+ }
+
+ return result;
+}
+
+///////////////////////////////////////////////////////////////////////////////
// Run function for the pass
///////////////////////////////////////////////////////////////////////////////
@@ -4307,6 +4422,25 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
+ order.clear();
+ GetReversePostOrder(**g, &order); // This will give us topological sort.
+ for (Node* n : order) {
+ // If the node is not an op or it cannot run on a CPU device, then skip.
+ if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
+ continue;
+ }
+ if (FixMklMetaDataEdges(g, n)) {
+ string node_name = n->name();
+ string op_name = n->type_string();
+
+ VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
+ << node_name << " with op " << op_name;
+ result = true;
+ }
+ }
+ DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
+ &**g);
+
return result;
}