diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 595 |
1 files changed, 385 insertions, 210 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 09b632a165..94741a11ff 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/util/mkl_util.h" @@ -272,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; + csinfo_.identity = "Identity"; csinfo_.lrn = "LRN"; csinfo_.lrn_grad = "LRNGrad"; csinfo_.matmul = "MatMul"; @@ -280,51 +282,75 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_conv2d = "_MklConv2D"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = - "_MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.reshape = "Reshape"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.split = "Split"; + "_MklConv2DWithBiasBackpropBias"; + csinfo_.relu = "Relu"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.reshape = "Reshape"; + csinfo_.split = "Split"; // NOTE: names are alphabetically sorted. - rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1, - CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.avg_pool, + GetMklOpName(csinfo_.avg_pool), + CopyAttrsPooling, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.avg_pool_grad, - GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0, - CopyAttrsConcat, AlwaysRewrite}); - rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0, - CopyAttrsConcatV2, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.avg_pool_grad), + CopyAttrsPooling, AlwaysRewrite, nullptr}); + // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending + // on if context contains Conv2D. + rinfo_.push_back({csinfo_.bias_add_grad, + csinfo_.mkl_conv2d_with_bias_backprop_bias, + CopyAttrsBiasAddGrad, ContextMatchRewrite, + &biasaddgrad_conv2dwithbias_context_}); + // BiasAddGrad gets written into BiasAddGrad depending on if context + // contains MatMul. + rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul, + CopyAttrsBiasAddGrad, ContextMatchRewrite, + &biasaddgrad_matmul_context_}); + rinfo_.push_back({csinfo_.concat, + GetMklOpName(csinfo_.concat), + CopyAttrsConcat, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.concatv2, + GetMklOpName(csinfo_.concatv2), + CopyAttrsConcatV2, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.conv2d, + GetMklOpName(csinfo_.conv2d), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_filter, - GetMklOpName(csinfo_.conv2d_grad_filter), 3, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_filter), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.conv2d_grad_input, - GetMklOpName(csinfo_.conv2d_grad_input), 3, - CopyAttrsConv2D, AlwaysRewrite}); + GetMklOpName(csinfo_.conv2d_grad_input), + CopyAttrsConv2D, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm, - GetMklOpName(csinfo_.fused_batch_norm), 5, - CopyAttrsFusedBatchNorm, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); rinfo_.push_back({csinfo_.fused_batch_norm_grad, - GetMklOpName(csinfo_.fused_batch_norm_grad), 5, - CopyAttrsFusedBatchNorm, AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3, - CopyAttrsLRN, AlwaysRewrite}); - rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1, - CopyAttrsPooling, AlwaysRewrite}); + GetMklOpName(csinfo_.fused_batch_norm_grad), + CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.identity, + GetMklOpName(csinfo_.identity), + CopyAttrsIdentity, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.lrn, + GetMklOpName(csinfo_.lrn), + CopyAttrsLRN, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.lrn_grad, + GetMklOpName(csinfo_.lrn_grad), + CopyAttrsLRN, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.max_pool, + GetMklOpName(csinfo_.max_pool), + CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); rinfo_.push_back({csinfo_.max_pool_grad, - GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling, - AlwaysRewrite}); - rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1, - CopyAttrsRelu, AlwaysRewrite}); - rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2, - CopyAttrsReshape, AlwaysRewrite}); - - // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet. + GetMklOpName(csinfo_.max_pool_grad), + CopyAttrsPooling, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.relu, + GetMklOpName(csinfo_.relu), + CopyAttrsRelu, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.relu_grad, + GetMklOpName(csinfo_.relu_grad), + CopyAttrsRelu, AlwaysRewrite, nullptr}); + rinfo_.push_back({csinfo_.reshape, + GetMklOpName(csinfo_.reshape), + CopyAttrsReshape, AlwaysRewrite, nullptr}); // Add info about which ops to add workspace edge to and the slots. wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); @@ -338,8 +364,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // maxhops in backward data-flow graph. Since input of forward nodes // (Conv2D) directly goes to backward nodes, we do not expect the // hop-distance would be more than few nodes. - cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, - kNodeMergeContextMaxDepth}); + biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, + kNodeMergeContextMaxDepth}; + + biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, + csinfo_.mkl_conv2d_with_bias, + kNodeMergeContextMaxDepth}; + + cinfo_.push_back(&biasaddgrad_matmul_context_); + cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); } // Standard interface to run pass @@ -354,7 +387,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @return true, if and only if graph is mutated; false otherwise. bool RunPass(std::unique_ptr<Graph>* g); - private: + /// Structure to specify the context information used in a node rewrite rule + typedef struct { + string node; // Name of the node to be rewritten + string fwd; // Name of the node in the forward pass that this node + // corresponds to + size_t max_hop; // Maximum number of hops the fwd is located + // from this node. If the fwd is farther than max_hop + // then we do not rewrite the node. + } ContextInfo; + /// Structure to specify the name of an original node, its new name after /// rewrite, the number of inputs to the original node, the function to /// be used to copy attributes for the op, and the rule (if any) which @@ -362,11 +404,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { typedef struct { string name; // Original name of op of the node in the graph string new_name; // New name of the op of the node in the graph - int num_ins; // The number of inputs to the original op type // A function handler to copy attributes from an old node to a new node. std::function<void(const Node*, NodeBuilder*)> copy_attrs; - std::function<bool(const Node*)> rewrite_rule; // A rule under which to - // rewrite this node. + // A rule under which to rewrite this node + std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule; + // ContextInfo, if any, to be used for rewrite + ContextInfo* context; } RewriteInfo; /// Structure to specify a forward op, a backward op, and the slot numbers @@ -393,16 +436,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string new_node; // Name of the node after merge } MergeInfo; - /// Structure to specify the context information used in a node rewrite rule - typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Name of the node in the forward pass that this node - // corresponds to - size_t max_hop; // Maximum number of hops the fwd is located - // from this node. If the fwd is farther than max_hop - // then we do not rewrite the node. - } ContextInfo; - /// Structure to store all constant strings /// NOTE: names are alphabetically sorted. struct { @@ -417,6 +450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string conv2d_grad_filter; string fused_batch_norm; string fused_batch_norm_grad; + string identity; string lrn; string lrn_grad; string matmul; @@ -427,10 +461,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string mkl_conv2d_with_bias_backprop_bias; string relu; string relu_grad; - string split; string reshape; + string split; } csinfo_; + private: /// Maintain info about nodes to rewrite std::vector<RewriteInfo> rinfo_; @@ -441,7 +476,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::vector<MergeInfo> minfo_; /// Maintain info about nodes to rewrite - static std::vector<ContextInfo> cinfo_; + static std::vector<ContextInfo*> cinfo_; + + /// Context variables used in referencing rules + static ContextInfo biasaddgrad_matmul_context_; + static ContextInfo biasaddgrad_conv2dwithbias_context_; /// Hash table to maintain nodes visited in the graph. std::unordered_set<const Node*> visited_nodes_; @@ -464,19 +503,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Clear all visited nodes inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); } - // Is this a graph node that can accept variable number of inputs? - // Return true if yes, false otherwise. - // - // Concat, Split are vararg nodes. - inline bool IsVarArgNode(Node* n) { - if (n->type_string() == csinfo_.concat || - n->type_string() == csinfo_.concatv2 || - n->type_string() == csinfo_.split) { - return true; - } - return false; - } - // Is OpDef::ArgDef a list type? It could be N * T or list(type). // Refer to opdef.proto for details of list type. inline bool ArgIsList(const OpDef::ArgDef& arg) const { @@ -510,6 +536,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return string(kMklOpPrefix) + name; } + // Can op represented by node 'n' run on DEVICE_CPU? + // Op can run on CPU with MKL if the runtime assigned device or the + // user requested device contains device CPU, or both are empty. + bool CanOpRunOnCPUDevice(const Node* n) { + bool result = true; + string reason; + + // Substring that should be checked for in device name for CPU device. + const char* const kCPUDeviceSubStr = "cpu"; + + // If Op has been specifically assigned to a non-CPU device, then No. + if (!n->assigned_device_name().empty() && + !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) { + result = false; + reason = "Op has been assigned a runtime device that is not CPU."; + } + + // If user has specifically assigned this op to a non-CPU device, then No. + if (!n->def().device().empty() && + !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) { + result = false; + reason = "User has assigned a device that is not CPU."; + } + + if (result == false) { + VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " + << n->type_string() << ", reason: " << reason; + } + + // Otherwise Yes. + return result; + } + // Return a node that can be merged with input node 'n' // // @return pointer to the node if we can find such a @@ -538,13 +597,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Default rewrite rule to be used in scenario 1 for rewrite. // @return - true (since we want to always rewrite) - static bool AlwaysRewrite(const Node* n) { return true; } - // Rewrite rule that uses context-information for matching + static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) { + return true; + } + + // Check if we are performing pooling on depth or batch. If it is, then we + // do not rewrite MaxPool node to Mkl version. + // @return - true (if it is not a depth/batch wise pooling case); + // false otherwise. + static bool NonDepthBatchWisePoolRewrite(const Node* n, + const ContextInfo* c) { + CHECK_NOTNULL(n); + + string data_format_str; + TensorFormat data_format; + std::vector<int32> ksize, strides; + CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); + CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); + CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), + true); + CHECK_EQ(FormatFromString(data_format_str, &data_format), true); + + // Condition that specifies non-batch-wise and non-depth-wise pooling. + if (GetTensorDim(ksize, data_format, 'N') == 1 && + GetTensorDim(strides, data_format, 'N') == 1 && + GetTensorDim(ksize, data_format, 'C') == 1 && + GetTensorDim(strides, data_format, 'C') == 1) { + return true; + } + + return false; + } + + // Rewrite rule that uses context-information for matching, // used in scenario 2. // // @input - Node 'n' for which to search for matching context - // @return - true if matching context is found; false otherwise. - static bool ContextMatchRewrite(const Node* n); + // @input - The context 'c' under which to rewrite + // @return - true if we can rewrite node under context 'c'; + // false otherwise. + static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); // Helper function that searches the matching contextinfo for the node. // Implements depth-first search in the data dependence graph for the @@ -598,6 +690,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // node that we are constructing. // // @input g - input graph, + // @input orig_node - Original node that we are rewriting // @input inputs - inputs to old node that we are using for constructing // new inputs, // @input input_idx - the index in the 'inputs' vector pointing to the @@ -608,11 +701,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output output_nodes - the list of new nodes creating Mkl tensors // // @return None - void GetNodesProducingMklTensorList( - std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, - int* input_idx, int list_length, - std::vector<NodeBuilder::NodeOut>* output_nodes); + void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g, + Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, + int* input_idx, int list_length, + std::vector<NodeBuilder::NodeOut>* output_nodes); // Get a node that will feed an Mkl tensor to the new // node that we are constructing. The output node could be (1) 'n' @@ -620,6 +712,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // if 'n' is not an Mkl layer. // // @input g - input graph, + // @input orig_node - Original node that we are rewriting, // @input n - Node based on which we are creating Mkl node, // @input n_output_slot - the output slot of node 'n' // which is feeding to the node that we are constructing @@ -627,9 +720,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // @output mkl_node_output_slot - the slot number of mkl_node that // will feed the tensor // @return None - void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n, - int n_output_slot, Node** mkl_node, - int* mkl_node_output_slot); + void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, + Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot); // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are @@ -680,6 +772,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb); @@ -695,13 +788,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Node* orig_node); }; -std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_; +MklLayoutRewritePass::ContextInfo + MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; +MklLayoutRewritePass::ContextInfo + MklLayoutRewritePass::biasaddgrad_matmul_context_; +std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; -// We register Mkl rewrite pass for phase 1 in post rewrite group. +// We register Mkl rewrite pass for phase 1 in post partitioning group. // We register it here so that we get a complete picture of all users of Mkl // nodes. Do not change the ordering of the Mkl passes. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1, - MklLayoutRewritePass); +const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = + OptimizationPassRegistry::POST_PARTITIONING; +REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -737,27 +835,14 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList( while (list_length != 0) { CHECK_GT(list_length, 0); - CHECK_LE(*input_idx, inputs.size()); + CHECK_LT(*input_idx, inputs.size()); Node* n = inputs[*input_idx].first; int slot = inputs[*input_idx].second; - const OpDef::ArgDef& arg = n->op_def().output_arg(slot); - // If input node 'n' is producing a list/array output at output - // slot 'slot' then we need to find out the length of that list/array. - if (ArgIsList(arg)) { - int N = GetTensorListLength(arg, n); - CHECK_LE(N, list_length); - for (int j = 0; j < N; j++) { - output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); - } - (*input_idx)++; - list_length -= N; - } else { - // But if input node 'n' is just producing a single tensor at - // output slot 'slot' then we just add that single node. - output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); - (*input_idx)++; - list_length--; - } + // If input node 'n' is just producing a single tensor at + // output slot 'slot' then we just add that single node. + output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); + (*input_idx)++; + list_length--; } } @@ -775,20 +860,39 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); + + // If number of inputs to the original node is > 0, then we add + // control dependency between 1st input (index 0) of the original node and + // the dummy Mkl node. This is needed because control-flow ops such as Enter, + // Merge, etc, require frame_name of the dummy Mkl node to be same as the + // rewritten node. Adding control edge between 1st input of the original node + // and the dummy Mkl node ensures that the dummy node is in the same frame + // as the original node. Choosing 1st input is not necessary - any input of + // the original node is fine because all the inputs of a node are always in + // the same frame. + if (orig_node->num_inputs() > 0) { + Node* orig_input0 = nullptr; + TF_CHECK_OK(orig_node->input_node(0, + const_cast<const Node**>(&orig_input0))); + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + } + (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } void MklLayoutRewritePass::GetNodesProducingMklTensorList( std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, - int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { + Node* orig_node, + const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, + int* input_idx, int list_length, + std::vector<NodeBuilder::NodeOut>* output_nodes) { CHECK_LT(*input_idx, inputs.size()); CHECK_GT(list_length, 0); CHECK_NOTNULL(output_nodes); @@ -796,38 +900,19 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( while (list_length != 0) { CHECK_GT(list_length, 0); - CHECK_LE(*input_idx, inputs.size()); + CHECK_LT(*input_idx, inputs.size()); Node* n = inputs[*input_idx].first; int slot = inputs[*input_idx].second; - const OpDef::ArgDef& arg = n->op_def().output_arg(slot); - // We need to check first if the input edge is going to carry a - // single tensor or a list of tensors. If it is a list of tensors, - // then we need to create list of Mkl dummy nodes. - if (ArgIsList(arg)) { - // If input node 'n' is producing a list/array output at output - // slot 'slot' then we need to find out the length of that list/array. - int N = GetTensorListLength(arg, n); - CHECK_LE(N, list_length); - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - // If it is a list, then create a list of Mkl dummy nodes. - for (int j = 0; j < N; j++) { - GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back( - NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); - } - (*input_idx)++; - list_length -= N; - } else { - // If it is not a list, then create a single Mkl tensor node. - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); - output_nodes->push_back( - NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); - (*input_idx)++; - list_length--; - } + // If 'n' is producing a single tensor, then create a single Mkl tensor + // node. + Node* mkl_node = nullptr; + int mkl_node_output_slot = 0; + GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, + &mkl_node_output_slot); + output_nodes->push_back(NodeBuilder::NodeOut(mkl_node, + mkl_node_output_slot)); + (*input_idx)++; + list_length--; } } @@ -835,9 +920,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList( // node that we are constructing. An input node could be (1) 'n' // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor // if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor( - std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node, - int* mkl_node_output_slot) { +void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, + Node* orig_node, Node* n, + int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { CHECK_NOTNULL(n); CHECK_NOTNULL(mkl_node); CHECK_NOTNULL(mkl_node_output_slot); @@ -860,7 +945,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor( // to create a dummy node that will feed a dummy Mkl tensor to this node. // DummyMklTensor node has no input and generates only 1 output // (dummy Mkl tensor) as output slot number 0. - GetDummyMklTensorNode(g, mkl_node, n); + GetDummyMklTensorNode(g, mkl_node, orig_node); CHECK_NOTNULL(*mkl_node); *mkl_node_output_slot = 0; } @@ -926,16 +1011,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs( if (ArgIsList(arg)) { std::vector<NodeBuilder::NodeOut> new_node_inputs; int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N, - &new_node_inputs); + GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, + N, &new_node_inputs); nb->Input(new_node_inputs); nn_slot_idx++; } else { Node* mkl_node = nullptr; int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, &mkl_node, - &mkl_node_output_slot); + GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, + old_node_inputs[iidx].second, + &mkl_node, &mkl_node_output_slot); nb->Input(mkl_node, mkl_node_output_slot); iidx++; nn_slot_idx++; @@ -1020,13 +1105,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( TensorShape dummy_shape({1}); dummy_shape.AsProto(proto.mutable_tensor_shape()); TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // same the device as the + // device of the original + // node. + .Finalize(&**g, out)); + + // If number of inputs to the original node is > 0, then we add + // control dependency between 1st input (index 0) of the original node and + // the dummy Mkl node. This is needed because control-flow ops such as Enter, + // Merge, etc, require frame_name of the dummy Mkl node to be same as the + // rewritten node. Adding control edge between 1st input of the original node + // and the dummy Mkl node ensures that the dummy node is in the same frame + // as the original node. Choosing 1st input is not necessary - any input of + // the original node is fine because all the inputs of a node are always in + // the same frame. + if (orig_node->num_inputs() > 0) { + Node* orig_input0 = nullptr; + TF_CHECK_OK(orig_node->input_node(0, + const_cast<const Node**>(&orig_input0))); + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + } + (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } @@ -1179,6 +1281,16 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, nb->Attr("data_format", data_format); } +void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + // Add attributes to new node. + nb->Attr("T", T); +} + void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb) { DataType T; @@ -1235,6 +1347,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node, nb->Attr("T", T); } +void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + DataType Tshape; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("Tshape", Tshape); +} + void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb) { DataType T; @@ -1303,20 +1428,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, nb->Attr("is_training", is_training); } -void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - DataType Tshape; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("Tshape", Tshape); -} - ////////////////////////////////////////////////////////////////////////// // Helper functions related to node merge pass ////////////////////////////////////////////////////////////////////////// @@ -1353,8 +1464,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { continue; } + const int B_in = b->num_inputs(); gtl::InlinedVector<Node*, 4> b_control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in); + gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); FillInputs(b, &b_control_edges, &b_in); // Shouldn't merge if a and b have different control edges. @@ -1438,7 +1550,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, CHECK_EQ(succ->in_edges().size(), 2); Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3 int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0. - GetDummyMklTensorNode(g, &oper3_mkl, succ); // Get dummy Mkl tensor node + GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node // as BiasAdd does not have Mkl tensor as input. CHECK_NOTNULL(oper3_mkl); @@ -1483,9 +1595,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, // Set the Mkl layer label for this op. new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); + // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' + // node are already copied in BuildNode. We handle control edges now. + for (const Edge* e : pred->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + for (const Edge* e : succ->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + // Incoming edges are fixed, we will fix the outgoing edges now. + // First, we will fix outgoing control edges from 'pred' node. + // We don't need to handle outgoing data edges from 'pred' node + // because pred has only 1 output going to succ node (we enforced + // this check for merge already). + for (const Edge* e : pred->out_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + } + } + + // Second, we will fix outgoing control and data edges from 'succ' node. for (const Edge* e : succ->out_edges()) { - (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + } else { + CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(), + e->dst_input())); + } } // Copy device assigned to old node to new node. @@ -1550,18 +1691,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, "data_format or T attribute or devices of BiasAddGrad and " "Conv2D do not match. Will skip node rewrite optimization"); } + } else if (orig_node->type_string() == csinfo_.bias_add_grad && + ri->new_name == csinfo_.matmul) { + // When BiasAddGrad has MatMul in context, we do not do any rewrite + // and leave BiasAddGrad as it is. But we check for this condition + // when we check for node rewrite rule. So we should not even come + // here for MatMul. So we will fail now. + return Status( + error::Code::INVALID_ARGUMENT, + "No rewrite is required for BiasAddGrad for MatMul context."); } } // Get all inputs. - const int num = orig_node->in_edges().size(); - // Check the number of inputs against the user-specified value for non-vararg - // nodes. - if (!IsVarArgNode(orig_node)) { - CHECK_EQ(num, ri->num_ins); - } + const int num_inputs = orig_node->in_edges().size(); gtl::InlinedVector<Node*, 4> control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num); + gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); FillInputs(orig_node, &control_edges, &inputs); // Build new node. We use same name as original node, but change the op name. @@ -1596,8 +1741,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, TF_CHECK_OK(nb.Finalize(&**g, &new_node)); CHECK_NOTNULL(new_node); - // Incoming edges from 'orig_node' node to new 'new_node' node are already - // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new + // Incoming data edges from 'orig_node' node to new 'new_node' node are + // already copied in BuildNode. We need to handle control edges now. + for (const Edge* e : orig_node->in_edges()) { + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + } + } + + // Copy outgoing edges from 'orig_node' node to new // 'new_node' node, since the output also follows same ordering among // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow // tensors appropriately. Specifically, nth output of the original node @@ -1605,15 +1757,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, // of the tensors. For the contiguous ordering of the tensors, it will be n. // GetTensorDataIndex provides this mapping function. for (const Edge* e : orig_node->out_edges()) { - // We need to handle control-edges by using their original slot number. - // Generally, -1 is reserved for control slot. - if (e->src_output() < 0) { - (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); + if (e->IsControlEdge()) { + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); } else { - (*g)->AddEdge( - new_node, - GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), - e->dst(), e->dst_input()); + CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), + e->src()->num_outputs()), + e->dst(), e->dst_input())); } } @@ -1640,8 +1789,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, bool is_matching_cinfo_found = false; std::vector<const ContextInfo*> mci; for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { - if (n->type_string() == ci->node) { - mci.push_back(&*ci); + if (n->type_string() == (*ci)->node) { + mci.push_back(*ci); is_matching_cinfo_found = true; } } @@ -1701,9 +1850,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, return nullptr; } -bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) { +bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n, + const ContextInfo* c) { const Node* fwd_node = nullptr; - return SearchMatchingContext(n, &fwd_node) != nullptr; + return SearchMatchingContext(n, &fwd_node) == c; } const MklLayoutRewritePass::RewriteInfo* @@ -1719,18 +1869,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { return nullptr; } - if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { - return nullptr; + // BiasAddGrad is not an Mkl layer, so we make an exception for it. + if (n->type_string() != csinfo_.bias_add_grad) { + if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { + return nullptr; + } } // We support 2 types of node rewrites: - // 1. Rewriting BiasAddGrad depending on its context. + // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context. // 2. Rewriting an op to Mkl op always // We return true if any of these 2 conditions is met. // Find matching RewriteInfo and then check that rewrite rule applies. for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { + if (n->type_string().compare(ri->name) == 0 && + ri->rewrite_rule(n, ri->context)) { + // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context, + // then we just return directly. + if (n->type_string() == csinfo_.bias_add_grad && + ri->context->fwd == csinfo_.matmul && + ri->new_name == csinfo_.bias_add_grad) { + return nullptr; + } return &*ri; } } @@ -1753,7 +1914,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { GetReversePostOrder(**g, &order); // This will give us topological sort. for (Node* n : order) { - if (!n->IsOp()) { + // If node is not an op or it cannot run on CPU device, then skip. + if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { continue; } @@ -1801,18 +1963,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { return MklLayoutRewritePass().RunPass(g); } -Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { - if (options.graph == nullptr) { +Status MklLayoutRewritePass::Run( + const GraphOptimizationPassOptions& options) { + if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } - // Get the ownership of graph - std::unique_ptr<Graph>* g = std::move(options.graph); - - RunPass(g); - - // Return the ownership of graph back - options.graph->reset(g->release()); + auto process_graph = [&](std::unique_ptr<Graph>* g) { + // Get the ownership of a graph + std::unique_ptr<Graph>* ng = std::move(g); + RunPass(ng); + // Return the ownership of a graph back + g->reset(ng->release()); + }; + + if (kMklLayoutRewritePassGroup != + OptimizationPassRegistry::POST_PARTITIONING) { + // For any pre-partitioning phase, a graph is stored in options.graph. + process_graph(options.graph); + } else { + // For post partitioning phase, graphs are stored in + // options.partition_graphs. + for (auto& pg : *options.partition_graphs) { + process_graph(&pg.second); + } + } return Status::OK(); } |