diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 148 |
1 files changed, 141 insertions, 7 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 72a13d4da7..b9667998d6 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2691,14 +2691,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // If Op has been specifically assigned to a non-CPU device, then No. if (!n->assigned_device_name().empty() && - !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) { + !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) { result = false; reason = "Op has been assigned a runtime device that is not CPU."; } // If user has specifically assigned this op to a non-CPU device, then No. if (!n->def().device().empty() && - !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) { + !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) { result = false; reason = "User has assigned a device that is not CPU."; } @@ -2865,9 +2865,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } - // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized - // path. The unoptimized path is slow. Thus we dont rewrite the node - // and use default Eigen. But for depth_radius=2, MKL DNN optimized + // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized + // path. The unoptimized path is slow. Thus we dont rewrite the node + // and use default Eigen. But for depth_radius=2, MKL DNN optimized // path is taken, i.e., eigen node is rewritten by MKl DNN node. 
static bool LrnRewrite(const Node* n) { CHECK_NOTNULL(n); @@ -2876,13 +2876,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true); // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN - // and use eigen node instead + // and use eigen node instead if (depth_radius == 2) { return true; } VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which" << "case is not optimized by Intel MKL, thus using Eigen op" - << "for LRN " ; + << "for LRN "; return false; } @@ -3015,6 +3015,35 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added); + // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge + // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph + // 'g'. Returns true if fixup was done; otherwise, it returns false. + bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, + const Edge* e_data, const Edge* e_metadata); + + // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly + // connected? If not, then fix them. This is needed because a graph may have + // some input Mkl metadata edges incorrectly set up after node merge and + // rewrite passes. This could happen because GetReversePostOrder function may + // not provide topologically sorted order if a graph contains cycles. The + // function returns true if at least one Mkl metadata edge for node 'n' was + // fixed. Otherwise, it returns false. + // + // Example: + // + // X = MklConv2D(_, _, _) + // Y = MklConv2DWithBias(_, _, _, _, _, _) + // Z = MklAdd(X, Y, DummyMklTensor, Y:1) + // + // For a graph such as shown above, note that 3rd argument of MklAdd contains + // DummyMklTensor. Actually, it should be getting the Mkl metadata from + // MklConv2D op (specifically, X:2). 
This incorrect plumbing could be possible + // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X + // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl + // metadata edges only - it does not rewrite nodes nor does it modify the Mkl + // data edges (1st and 2nd arguments of MklAdd). + bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n); + // Functions specific to operators to copy attributes // We need operator-specific function to copy attributes because the framework // does not provide any generic function for it. @@ -4242,6 +4271,92 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { } /////////////////////////////////////////////////////////////////////////////// +// Post-rewrite Mkl metadata fixup pass +/////////////////////////////////////////////////////////////////////////////// +bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, + const Edge* e_data, const Edge* e_metadata) { + if (g == nullptr || e_data == nullptr || e_metadata == nullptr) { + return false; + } + + Node* n_data = e_data->src(); + int n_data_op_slot = e_data->src_output(); + int n_metadata_op_slot = GetTensorMetaDataIndex(n_data_op_slot, + n_data->num_outputs()); + + // If the source of meta edge is a constant node (producing dummy Mkl metadata + // tensor), then we will need to fix. + if (IsConstant(e_metadata->src())) { + Node* e_metadata_dst = e_metadata->dst(); + int e_metadata_in_slot = e_metadata->dst_input(); + CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot, + e_metadata_dst, e_metadata_in_slot)); + + (*g)->RemoveEdge(e_metadata); + return true; + } + + return false; +} + +bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g, + Node* n) { + bool result = false; + + // If graph node is not Mkl node, then return. 
+ DataType T = DT_INVALID; + if (!GetNodeAttr(n->def(), "T", &T).ok() || + !mkl_op_registry::IsMklOp(n->type_string(), T)) { + return result; + } + + // If it is Mkl node, then check if the input edges to this node that carry + // Mkl metadata are linked up correctly with the source node. + + // For Mkl nodes, we generate twice the number of input tensors (n for Mkl + // data tensors + n for Mkl metadata tensors). We need to check for correct + // connection of n metadata tensors only. + int num_data_inputs = n->num_inputs() / 2; + for (int idx = 0; idx < num_data_inputs; idx++) { + // Get the edge connecting input slot with index (idx). + const Edge* e = nullptr; + TF_CHECK_OK(n->input_edge(idx, &e)); + + // If e is control edge, then skip. + if (e->IsControlEdge()) { + continue; + } + + // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl + // node, then we don't need to do anything. + Node* e_src = e->src(); + if (GetNodeAttr(e_src->def(), "T", &T).ok() && + mkl_op_registry::IsMklOp(e_src->type_string(), T)) { + // Source node for edge 'e' is Mkl node. + // Destination node and destination input slot of e is node 'n' and 'idx' + // resp. + CHECK_EQ(e->dst(), n); + CHECK_EQ(e->dst_input(), idx); + + // Let's get edge that carries Mkl metadata corresponding to Mkl data edge + // 'e'. For that, let's first get the input slot of 'n' where the meta + // edge will feed the value. + int e_meta_in_slot = GetTensorMetaDataIndex(e->dst_input(), + n->num_inputs()); + const Edge* e_meta = nullptr; + TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta)); + + // Let's check if we need to fix this meta edge. 
+ if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) { + result = true; + } + } + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// // Run function for the pass /////////////////////////////////////////////////////////////////////////////// @@ -4307,6 +4422,25 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g); + order.clear(); + GetReversePostOrder(**g, &order); // This will give us topological sort. + for (Node* n : order) { + // If node is not an op or it cannot run on CPU device, then skip. + if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { + continue; + } + if (FixMklMetaDataEdges(g, n)) { + string node_name = n->name(); + string op_name = n->type_string(); + + VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node " + << node_name << " with op " << op_name; + result = true; + } + } + DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)", + &**g); + return result; } |