diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 1275 |
1 files changed, 893 insertions, 382 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 309c4cd774..09b632a165 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -48,7 +48,7 @@ namespace tensorflow { // 1) Propagating Mkl layout as an additional output tensor // (we will loosely call a tensor that carries Mkl layout as Mkl tensor // henceforth.) from every Mkl supported NN layer. -// 2) Context-based rewrite: This is neded in order to optimize +// 2) Context-based rewrite: This is needed in order to optimize // gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and // MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into // Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad. @@ -63,12 +63,12 @@ namespace tensorflow { // P = BiasAdd(O, C) // // We merge them into Conv2DWithBias as: -// P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m) +// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) // -// Meaning of A_m, B_m and C_m is explained in B.1. +// The meaning of A_m, B_m and C_m is explained in B.1. // // Merge rules: -// - Merge for Conv2D and BiasAdd happens only when output of Conv2D _only_ +// - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ // goes to BiasAdd. // - Also, the intersection of attributes of both the nodes must have same // values. @@ -76,7 +76,7 @@ namespace tensorflow { // // Example of B.1 : Rewriting nodes to Mkl nodes // --------------------------------------------- -// Consider Relu layer. Current definition of Relu layer looks like: +// Consider a Relu node. Current definition of Relu node looks like: // // O = Relu(A) // @@ -87,58 +87,59 @@ namespace tensorflow { // // O, O_m = MklRelu(A, A_m) // -// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here A input is -// same as A input of Relu; O output is same as O output of Relu. O_m is the +// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). 
Here input A is +// same as input A of Relu; output O is same as output O of Relu. O_m is the // additional output tensor that will be set by MklRelu, and it represents // Mkl tensor corresponding to O -- in other words, O_m is some kind of // metadata for O. A_m is additional input of Relu, and it represents metadata // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives -// this metadata from previous layer (in the graph). +// this metadata from previous node in the graph. // -// When previous layer in the graph is Mkl layer, A_m will represent a valid -// Mkl tensor. But when previous Mkl layer is not an Mkl layer, then A_m -// represents a dummy Mkl tensor. +// When a previous node in the graph is an Mkl node, A_m will represent a valid +// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent +// a dummy Mkl tensor. // // Rewriting rules: -// - Selection of an op for rewriting happens by registering an op with this -// pass. If an op is not registered, then it is not rewritten. +// - Selection of a node for rewriting happens by registering the op type of +// the node with the rewriting pass. If the op type is not registered, then +// all nodes of this op type will not be rewritten. // - Number of inputs after rewriting: -// Since for every input Tensorflow tensor, the rewritten layer gets Mkl -// tensor, rewritten op gets 2*N inputs, where N is the number of inputs -// for original op. +// Since for every input Tensorflow tensor, the rewritten node gets Mkl +// tensor(s), rewritten node gets 2*N inputs, where N is the number of +// inputs for the original node. // - Number of outputs after rewriting: -// Since for every output Tensorflow tensor, the rewritten layer generates -// Mkl tensor, rewritten op generates 2*N outputs, where N is the number -// of outputs of original op. 
+// Since for every output Tensorflow tensor, the rewritten node generates +// Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the +// number of outputs of the original node. // - Ordering of Tensorflow tensors and Mkl tensors: -// Since every op generates twice the number of inputs and outputs, one -// could imagine different ordering among Tensorflow tensors and Mkl -// tensors. E.g., let's assume an op 'Conv2D' takes (A, B) as input, then -// new op 'MklConv2D' can take (A, A_m, B, B_m) as input or it can also -// take (A, B, A_m, B_m) as input. Among N inputs one can get N! -// permutations. +// Since every rewritten node generates twice the number of inputs and +// outputs, one could imagine various orderings among Tensorflow tensors +// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as +// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m +// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m +// order. Among N inputs one can get N! permutations. // -// So the question is: which one do we follow? Currently, we follow an -// intuitive order where Mkl tensor follows a corresponding Tensorflow -// tensor immediately. In the context of above example, it will be: (A, -// A_m, B, B_m). We follow same ordering rule for output tensors. -// -// NOTE: Current rewriting approach rewrites an op to Mkl op without any -// conditions. But in the future, it may be possible to consider -// conditions such as input shapes and sizes to rewrite an op. +// So the question is: which order do we follow? We support 2 types of +// orderings: (1) interleaved, and (2) contiguous. Interleaved ordering +// follows an intuitive order where an Mkl tensor follows the +// corresponding Tensorflow tensor immediately. In the context of the +// above example, it will be: A, A_m, B, B_m. Note that the ordering rule +// applies to both the inputs and outputs. 
Contiguous ordering means +// all the Tensorflow tensors are contiguous followed by all the Mkl +// tensors. We use contiguous ordering as default. // // Graph rewrite algorithm: // Algorithm: Graph Rewrite -// Input: Graph G, Names of nodes to rewrite and their new nodes -// Output: Modified Graph G' if nodes are modified, G otherwise. +// Input: Graph G, Names of the nodes to rewrite and their new names +// Output: Modified Graph G' if the nodes are modified, G otherwise. // Start: -// N = Topological_Sort(G) // N is set of nodes in toposort order. +// N = Topological_Sort(G) // N is a set of nodes in toposort order. // foreach node n in N // do -// if (Is_MKL_Layer(n)) // Can this layer accept Mkl layout as input. +// if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. // then // E = set of <incoming edge and its src_output slot> of n -// E' = {} // new set of edges for rewritten node +// E' = {} // a new set of edges for rewritten node // foreach <e,s> in E // do // E' U {<e,s>} // First copy edge which generates Tensorflow @@ -146,42 +147,44 @@ namespace tensorflow { // m = Source node of edge e // if Is_Rewritten(m) // Did we rewrite this node in this pass? // then -// E' U {<m,s+1>} // If yes, then m will generate Mkl tensor -// // as output. +// E' U {<m,s+1>} // If yes, then m will generate an Mkl +// // tensor as an additional output. // else -// d = Generate_Dummy_Mkl_Tensor() // If not, generate dummy +// d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy // // Mkl tensor. -// E' U {<d,0>} // Dummy Mkl tensor has only 1 output slot. +// E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot. // fi // done // n' = Build_New_Node(G,new_name,E') -// Mark_Rewritten(n') // Mark new node as being rewritten. +// Mark_Rewritten(n') // Mark the new node as being rewritten. // fi // done // // Explanation: -// For graph rewrite, we visit nodes of the graph in the topological -// sort order. 
With this ordering, we visit nodes in top-to-bottom -// fashion. We need this order because while visiting a node we want -// all of its input nodes (parents) visited (and rewritten if -// applicable). This is because if we need to rewrite a current node +// For graph rewrite, we visit nodes of the input graph in the +// topological sort order. With this ordering, we visit nodes in the +// top-to-bottom fashion. We need this order because while visiting a +// node we want that all of its input nodes are visited and rewritten if +// applicable. This is because if we need to rewrite a given node // then all of its input nodes need to be fixed (in other words they -// cannot be removed later.) +// cannot be deleted later.) // -// While visiting each node, we first check if it is Mkl layer. If -// it is, then we rewrite that node after constructing new inputs to -// the node. If it is not Mkl layer, then we do not rewrite the node. +// While visiting a node, we first check if the op type of the node is +// an Mkl op. If it is, then we rewrite that node after constructing +// new inputs to the node. If the op type of the node is not Mkl op, +// then we do not rewrite that node. // // Handling workspace propagation for certain ops: // // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require -// passing of workspace from their corresponding forward ops. But -// TensorFlow does not have a notion of workspace and as a result -// does not allow producing additional outputs from these forward ops. -// For these ops, we need to add an additional edge between forward -// ops and their corresponding backward ops, and this edge carries -// workspace tensor value and another edge carries Mkl tensor for -// workspace tensor. +// passing of a workspace from their respective forward ops. Workspace +// tensors provide memory for storing results of intermediate operations +// which are helpful in backward propagation. 
TensorFlow does not have +// a notion of a workspace and as a result does not allow producing +// additional outputs from these forward ops. For these ops, we need +// to add 2 extra edges between forward ops and their corresponding +// backward ops - the first extra edge carries a workspace tensor and +// the second one carries an Mkl tensor for the workspace tensor. // // Example: // @@ -190,59 +193,61 @@ namespace tensorflow { // A = MaxPool(T) // B = MaxPoolGrad(X, A, Y) // -// We will transform this graph to propagate workspace as: +// We will transform this graph to propagate the workspace as: +// (with the contiguous ordering) // -// A, A_m, W, W_m = MklMaxPool(T, T_m) -// B, B_m = MklMaxPoolGrad(X, X_m, A, A_m, Y, Y_m, W, W_m) +// A, W, A_m, W_m = MklMaxPool(T, T_m) +// B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) // -// Here W is the workspace tensor. Transformed tensors with name -// suffix _m are Mkl tensors and this transformation has been done +// Here W is the workspace tensor. Transformed tensor names with the +// suffix _m are Mkl tensors, and this transformation has been done // using the algorithm discussed earlier. The transformation for -// workspace only adds extra outputs (W, W_m) for forward op and -// connects them to corresponding backward ops. +// workspace propagation only adds extra outputs (W, W_m) for a forward +// op and connects them to the corresponding backward ops. 
// // Terms: // // Forward op name = name of the op in the forward pass -// where workspace originates (MaxPool in this example) +// where a workspace tensor originates (MaxPool in this example) // Backward op name = name of the op in the backward pass that receives -// workspace from forward op (MaxPoolGrad in the example) -// Slot = Number of the output or input slot that will be -// used by the workspace (2 for MklMaxPool as W is 3rd -// output of MaxPool (0 is 1st); 6 for MklMaxPoolGrad) +// a workspace tensor from the forward op (MaxPoolGrad in the example) +// Slot = Position of the output or input slot that will be +// used by the workspace tensor (1 for MklMaxPool as W is the 2nd +// output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad) // // Question: // -// How do we associate backward op to forward op? There can be more -// than one op with exact same name. +// How do we associate a backward op to a forward op? There can be more +// than one op with the exact same name. // -// In this example we associate MaxPoolGrad with MaxPool. But there +// In this example, we associate MaxPoolGrad with MaxPool. But there // could be more than one MaxPool ops. To solve this problem, we look -// for _direct_ edge between forward op and backward op (tensor A is -// flowing along this edge in the example.) +// for _direct_ edge between a forward op and a backward op (tensor A is +// flowing along this edge in the example). // -// How do we transform forward and backward op when there is no direct -// edge between them? In such case, we generate dummy tensors as +// How do we transform forward and backward ops when there is no direct +// edge between them? In such a case, we generate dummy tensors for // workspace tensors. For the example, transformation of MaxPool will -// be exactly same --- it is just that MaxPool won't generate any -// workspace tensor. 
For MaxPoolGrad, transformation will also be same, -// but instead of connecting W and W_m with outputs of MaxPool, we will -// produce dummy tensors for them, and we will set workspace_enabled -// attribute to false. +// be exactly same as it would be when there is a direct edge between +// the forward and the backward op --- it is just that MaxPool won't +// generate any workspace tensor. For MaxPoolGrad, the transformation +// will also be same, but instead of connecting W and W_m with the +// outputs of MaxPool, we will produce dummy tensors for them, and we +// will set workspace_enabled attribute to false. // // Example of B.2 : Context-based node rewrite // ------------------------------------------- // Consider BiasAddGrad op as: // -// O = MklConv2D(A, A_m, B, B_m, C, C_m) +// O = _MklConv2D(A, B, C, A_m, B_m, C_m) // P = BiasAddGrad(O) // -// Then we rewrite is as: +// Then we rewrite it as: // // P = Conv2DWithBiasBackpropBias(O, O_m) // -// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is -// the context matching depth. If MklConv2DWithBias is not within the context +// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is +// the context matching depth. If _MklConv2DWithBias is not within the context // matching depth, then we do not rewrite BiasAddGrad. // How many hops do we search for matching node in the backward dataflow graph? @@ -255,53 +260,85 @@ static size_t kNodeMergeContextMaxDepth = 10; class MklLayoutRewritePass : public GraphOptimizationPass { public: MklLayoutRewritePass() { + // NOTE: names are alphabetically sorted. 
+ csinfo_.avg_pool = "AvgPool"; + csinfo_.avg_pool_grad = "AvgPoolGrad"; + csinfo_.bias_add = "BiasAdd"; + csinfo_.bias_add_grad = "BiasAddGrad"; + csinfo_.concat = "Concat"; + csinfo_.concatv2 = "ConcatV2"; csinfo_.conv2d = "Conv2D"; - csinfo_.mklconv2d = "MklConv2D"; - csinfo_.mklconv2dwithbias = "MklConv2DWithBias"; - csinfo_.mklconv2dwithbiasbackpropbias = "MklConv2DWithBiasBackpropBias"; - csinfo_.biasadd = "BiasAdd"; + csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; + csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; + csinfo_.fused_batch_norm = "FusedBatchNorm"; + csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; + csinfo_.lrn = "LRN"; + csinfo_.lrn_grad = "LRNGrad"; csinfo_.matmul = "MatMul"; - csinfo_.biasaddgrad = "BiasAddGrad"; + csinfo_.max_pool = "MaxPool"; + csinfo_.max_pool_grad = "MaxPoolGrad"; + csinfo_.mkl_conv2d = "_MklConv2D"; + csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; + csinfo_.mkl_conv2d_with_bias_backprop_bias = + "_MklConv2DWithBiasBackpropBias"; csinfo_.relu = "Relu"; - csinfo_.relugrad = "ReluGrad"; - csinfo_.maxpool = "MaxPool"; - csinfo_.maxpoolgrad = "MaxPoolGrad"; - csinfo_.avgpool = "AvgPool"; - csinfo_.avgpoolgrad = "AvgPoolGrad"; - csinfo_.conv2dgradinput = "Conv2DBackpropInput"; - csinfo_.conv2dgradfilter = "Conv2DBackpropFilter"; - - rinfo_.push_back( - {csinfo_.conv2d, csinfo_.mklconv2d, 2, CopyAttrsConv2D, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2dgradfilter, - GetMklOpName(csinfo_.conv2dgradfilter), 3, + csinfo_.reshape = "Reshape"; + csinfo_.relu_grad = "ReluGrad"; + csinfo_.split = "Split"; + + // NOTE: names are alphabetically sorted. 
+ rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1, + CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.avg_pool_grad, + GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling, + AlwaysRewrite}); + rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0, + CopyAttrsConcat, AlwaysRewrite}); + rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0, + CopyAttrsConcatV2, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2, + CopyAttrsConv2D, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv2d_grad_filter, + GetMklOpName(csinfo_.conv2d_grad_filter), 3, CopyAttrsConv2D, AlwaysRewrite}); - rinfo_.push_back({csinfo_.conv2dgradinput, - GetMklOpName(csinfo_.conv2dgradinput), 3, CopyAttrsConv2D, + rinfo_.push_back({csinfo_.conv2d_grad_input, + GetMklOpName(csinfo_.conv2d_grad_input), 3, + CopyAttrsConv2D, AlwaysRewrite}); + rinfo_.push_back({csinfo_.fused_batch_norm, + GetMklOpName(csinfo_.fused_batch_norm), 5, + CopyAttrsFusedBatchNorm, AlwaysRewrite}); + rinfo_.push_back({csinfo_.fused_batch_norm_grad, + GetMklOpName(csinfo_.fused_batch_norm_grad), 5, + CopyAttrsFusedBatchNorm, AlwaysRewrite}); + rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN, + AlwaysRewrite}); + rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3, + CopyAttrsLRN, AlwaysRewrite}); + rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1, + CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.max_pool_grad, + GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling, AlwaysRewrite}); rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1, CopyAttrsRelu, AlwaysRewrite}); - rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool), 1, - CopyAttrsPooling, AlwaysRewrite}); - rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad), 3, - CopyAttrsPooling, AlwaysRewrite}); - rinfo_.push_back({csinfo_.avgpool, 
GetMklOpName(csinfo_.avgpool), 1, - CopyAttrsPooling, AlwaysRewrite}); - rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad), 2, - CopyAttrsPooling, AlwaysRewrite}); + rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2, + CopyAttrsReshape, AlwaysRewrite}); + + // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet. // Add info about which ops to add workspace edge to and the slots. - wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6}); + wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); + wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); // Add a rule for merging nodes - minfo_.push_back( - {csinfo_.mklconv2d, csinfo_.biasadd, 0, csinfo_.mklconv2dwithbias}); + minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, + csinfo_.mkl_conv2d_with_bias}); // We use maxhop of 10 based on empirical observations. Also, these are // maxhops in backward data-flow graph. Since input of forward nodes // (Conv2D) directly goes to backward nodes, we do not expect the // hop-distance would be more than few nodes. 
- cinfo_.push_back({csinfo_.biasaddgrad, csinfo_.mklconv2dwithbias, + cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, kNodeMergeContextMaxDepth}); } @@ -318,73 +355,80 @@ class MklLayoutRewritePass : public GraphOptimizationPass { bool RunPass(std::unique_ptr<Graph>* g); private: - /// Structure to specify name of original op, its new name after rewrite, - /// the number of inputs to the original op, and the function to be used - /// to copy attributes for the op + /// Structure to specify the name of an original node, its new name after + /// rewrite, the number of inputs to the original node, the function to + /// be used to copy attributes for the op, and the rule (if any) which + /// must hold for rewriting the node typedef struct { - string name; // Original name of the op in the graph - string newname; // New name of op in the graph - int numins; // Number of inputs to the original op - // Function handler to copy attributes from old node to new node. - std::function<void(const Node*, NodeBuilder*)> copyattrs; - std::function<bool(const Node*)> rewriterule; // Rule under which to - // rewrite this node. + string name; // Original name of op of the node in the graph + string new_name; // New name of the op of the node in the graph + int num_ins; // The number of inputs to the original op type + // A function handler to copy attributes from an old node to a new node. + std::function<void(const Node*, NodeBuilder*)> copy_attrs; + std::function<bool(const Node*)> rewrite_rule; // A rule under which to + // rewrite this node. } RewriteInfo; - /// Structure to specify forward op, backward op, and the slot numbers - /// in forward and backward op where we will add workspace edge. + /// Structure to specify a forward op, a backward op, and the slot numbers + /// in the forward and backward ops where we will add a workspace edge. 
typedef struct { - string fwdop; // Name of the forward op in the graph - string bwdop; // Name of the backward op in the graph - int fwdslot; // Output slot in the forward op node where actual - // output tensor resides - int bwdslot; // Input slot in the backward op node where actual - // input tensor resides - int wsfwdslot; // Output slot in the forward op node where workspace - // edge is added - int wsbwdslot; // Input slot in the backward op node where workspace - // edge is added + string fwd_op; // Name of a forward op in the graph + string bwd_op; // Name of a backward op in the graph + int fwd_slot; // Output slot in the forward op node where actual + // output tensor resides + int bwd_slot; // Input slot in the backward op node where actual + // input tensor resides + int ws_fwd_slot; // Output slot in the forward op node where workspace + // edge is added + int ws_bwd_slot; // Input slot in the backward op node where workspace + // edge is added } WorkSpaceInfo; /// Structure to specify information used in node merge typedef struct { - string pred; // Predecessor node string - string succ; // Successor node string - int op; // What operand no the predecessor node corresponds - // to successor node? - string newnode; // Name of the node after merge + string pred; // Predecessor node string + string succ; // Successor node string + int op; // The operand no the predecessor node corresponds + // to the successor node + string new_node; // Name of the node after merge } MergeInfo; - /// Structure to specify the context information used in node rewrite rule + /// Structure to specify the context information used in a node rewrite rule typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Node name in forward pass that this node - // corresponds to - size_t maxhop; // Maximum number of hops the fwd is located - // from this node. If fwd is farther than maxhop - // then we do not rewrite the node. 
+ string node; // Name of the node to be rewritten + string fwd; // Name of the node in the forward pass that this node + // corresponds to + size_t max_hop; // Maximum number of hops the fwd is located + // from this node. If the fwd is farther than max_hop + // then we do not rewrite the node. } ContextInfo; /// Structure to store all constant strings + /// NOTE: names are alphabetically sorted. struct { - string relu; - string relugrad; - // Conv ops + string avg_pool; + string avg_pool_grad; + string bias_add; + string bias_add_grad; + string concat; + string concatv2; string conv2d; - string mklconv2d; - string conv2dgradinput; - string conv2dgradfilter; - string mklconv2dwithbias; - string mklconv2dwithbiasbackpropbias; - // Pooling ops - string maxpool; - string maxpoolgrad; - string avgpool; - string avgpoolgrad; - // Others - string biasadd; + string conv2d_grad_input; + string conv2d_grad_filter; + string fused_batch_norm; + string fused_batch_norm_grad; + string lrn; + string lrn_grad; string matmul; - string biasaddgrad; + string max_pool; + string max_pool_grad; + string mkl_conv2d; + string mkl_conv2d_with_bias; + string mkl_conv2d_with_bias_backprop_bias; + string relu; + string relu_grad; + string split; + string reshape; } csinfo_; /// Maintain info about nodes to rewrite @@ -393,7 +437,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Maintain info about nodes to add workspace edge std::vector<WorkSpaceInfo> wsinfo_; - /// Maintain info to be merged + /// Maintain info about nodes to be merged std::vector<MergeInfo> minfo_; /// Maintain info about nodes to rewrite @@ -403,7 +447,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { std::unordered_set<const Node*> visited_nodes_; private: - // Predicate to check if we rewrote node 'n' + // Check if we rewrote node 'n' // // If we rewrote the node, then the rewritten node will produce // Mkl tensor as output. 
If we did not rewrite the node, then @@ -420,12 +464,49 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Clear all visited nodes inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); } + // Is this a graph node that can accept variable number of inputs? + // Return true if yes, false otherwise. + // + // Concat, Split are vararg nodes. + inline bool IsVarArgNode(Node* n) { + if (n->type_string() == csinfo_.concat || + n->type_string() == csinfo_.concatv2 || + n->type_string() == csinfo_.split) { + return true; + } + return false; + } + + // Is OpDef::ArgDef a list type? It could be N * T or list(type). + // Refer to opdef.proto for details of list type. + inline bool ArgIsList(const OpDef::ArgDef& arg) const { + return !arg.type_list_attr().empty() || !arg.number_attr().empty(); + } + + // Get length of a list in 'n' if 'arg' is of list type. Refer to + // description of ArgIsList for definition of list type. + inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { + CHECK_EQ(ArgIsList(arg), true); + int N = 0; + const string attr_name = !arg.type_list_attr().empty() + ? arg.type_list_attr() + : arg.number_attr(); + if (!arg.type_list_attr().empty()) { + std::vector<DataType> value; + TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); + N = value.size(); + } else { + TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); + } + return N; + } + // Get the name of Mkl op from original TensorFlow op // We prefix 'Mkl' to the original op to get Mkl op. // TODO(nhasabni) We should move this to mkl_util.h. inline string GetMklOpName(const string& name) const { // Prefix that we add to Tensorflow op name to construct Mkl op name. - const char* const kMklOpPrefix = "Mkl"; + const char* const kMklOpPrefix = "_Mkl"; return string(kMklOpPrefix) + name; } @@ -440,7 +521,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // // Input nodes succ and pred may be deleted if the call to // this function is successful. 
Attempt to use the pointers - // after the call to function may result is undefined behaviors. + // after the call to function may result in undefined behaviors. // // @input g - input graph, succ - successor node, pred - predecessor node // @return Status::OK(), if merging is successful and supported. @@ -470,13 +551,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // gradient op in the backward direction. // // @input n - Node (gradient op) whose contextinfo is to be searched, - // fwdn - pointer to node from the forward pass that this node - // belongs to. fwdn cannot be NULL. + // fwd_node - pointer to node from the forward pass that this node + // belongs to. fwd_node cannot be NULL. // @return Matching contextinfo in case a match is found; null otherwise. - // Also updates *fwdn with pointer to forward node that this context - // matches. + // Also updates *fwd_node with pointer to forward node that this + // context matches. static const ContextInfo* SearchMatchingContext(const Node* n, - const Node** fwdn); + const Node** fwd_node); // Rewrites input node to a new node specified by its matching rewrite info. // @@ -494,46 +575,132 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // Otherwise, it is not updated. Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); + // Get nodes that will feed a list of TF tensors to the new + // node that we are constructing. 
+ // + // @input g - input graph, + // @input inputs - inputs to old node that we are using for constructing + // new inputs, + // @input input_idx - the index in the 'inputs' vector pointing to the + // current input that we have processed so far + // @output input_idx - index will be incremented by the number of nodes + // from 'inputs' that are processed + // @input list_length - The expected length of list of TF tensors + // @output output_nodes - the list of new nodes creating TF tensors + // + // @return None + void GetNodesProducingTFTensorList( + const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, + int* input_idx, int list_length, + std::vector<NodeBuilder::NodeOut>* output_nodes); + + // Get nodes that will feed a list of Mkl tensors to the new + // node that we are constructing. + // + // @input g - input graph, + // @input inputs - inputs to old node that we are using for constructing + // new inputs, + // @input input_idx - the index in the 'inputs' vector pointing to the + // current input that we have processed so far + // @output input_idx - index will be incremented by the number of nodes + // from 'inputs' that are processed + // @input list_length - The expected length of list of Mkl tensors + // @output output_nodes - the list of new nodes creating Mkl tensors + // + // @return None + void GetNodesProducingMklTensorList( + std::unique_ptr<Graph>* g, + const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, + int* input_idx, int list_length, + std::vector<NodeBuilder::NodeOut>* output_nodes); + + // Get a node that will feed an Mkl tensor to the new + // node that we are constructing. The output node could be (1) 'n' + // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor + // if 'n' is not an Mkl layer. 
+ // + // @input g - input graph, + // @input n - Node based on which we are creating Mkl node, + // @input n_output_slot - the output slot of node 'n' + // which is feeding to the node that we are constructing + // @output mkl_node - the new node that will feed Mkl tensor + // @output mkl_node_output_slot - the slot number of mkl_node that + // will feed the tensor + // @return None + void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n, + int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot); + // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' - // in graph 'g'. Original node is input in 'orign'. + // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are + // set up in contiguous fashion. 'workspace_tensors' carry graph nodes + // producing workspace edges if 'are_workspace_tensors_available' is true. + // Otherwise, 'workspace_tensors' is empty vector. // - // For details, refer to 'Number of inputs after rewriting' section in the + // For details, refer to 'Ordering of inputs after rewriting' section in the // documentation above. // // Returns Status::OK() if setting up inputs is successful, otherwise // returns appropriate status code. + int SetUpContiguousInputs( + std::unique_ptr<Graph>* g, + const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, + NodeBuilder* nb, Node* old_node, + std::vector<NodeBuilder::NodeOut>* workspace_tensors, + bool are_workspace_tensors_available); + + // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' + // in graph 'g'. Original node is input in 'orig_node'. + // + // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' + // section in the documentation above. + // + // Returns Status::OK() if setting up inputs is successful, otherwise + // returns appropriate status code. 
Status SetUpInputs(std::unique_ptr<Graph>* g, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, - NodeBuilder* nb, Node* orign); - - // Add workspace edge on the input or output side of Node 'orign' by using - // NodeBuilder 'nb' for the new node provided. If 'orign' does not dictate - // adding workspace edge then do not add it. - void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign, - NodeBuilder* nb); + NodeBuilder* nb, Node* orig_node); + + // Add workspace edge on the input or output side of Node 'orig_node' by using + // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate + // adding workspace edge then do not add it. Workspace Tensorflow and Mkl + // tensors, if they need to be added, will be set into these tensors. + // If we set workspace tensors, then are_ws_tensors_added should be true. + void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, + NodeBuilder* nb, + std::vector<NodeBuilder::NodeOut>* ws_tensors, + bool* are_ws_tensors_added); // Functions specific to operators to copy attributes // We need operator-specific function to copy attributes because the framework // does not provide any generic function for it. - static void CopyAttrsConv2D(const Node* orign, NodeBuilder* nb); - static void CopyAttrsBiasAddGrad(const Node* orign, NodeBuilder* nb); - static void CopyAttrsPooling(const Node* orign, NodeBuilder* nb); - static void CopyAttrsRelu(const Node* orign, NodeBuilder* nb); + // NOTE: names are alphabetically sorted. 
+ static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, - // using node for original node 'orign' and return it in '*out'. + // using node for original node 'orig_node' and return it in '*out'. // TODO(nhasabni) We should move this to mkl_util.h void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, - Node* orign); + Node* orig_node); void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, - Node* orign); + Node* orig_node); }; std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_; -// We register Mkl rewrite pass for phase 1 in pre-placement group. -// Do not change the ordering of the Mkl passes. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1, +// We register Mkl rewrite pass for phase 1 in post rewrite group. +// We register it here so that we get a complete picture of all users of Mkl +// nodes. Do not change the ordering of the Mkl passes. 
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1, MklLayoutRewritePass); ////////////////////////////////////////////////////////////////////////// @@ -543,7 +710,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1, static void FillInputs(const Node* n, gtl::InlinedVector<Node*, 4>* control_edges, gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { - DCHECK_EQ(in->size(), n->num_inputs()); control_edges->clear(); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) { @@ -561,9 +727,43 @@ static void FillInputs(const Node* n, } } +void MklLayoutRewritePass::GetNodesProducingTFTensorList( + const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, + int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { + CHECK_LT(*input_idx, inputs.size()); + CHECK_GT(list_length, 0); + CHECK_NOTNULL(output_nodes); + output_nodes->reserve(list_length); + + while (list_length != 0) { + CHECK_GT(list_length, 0); + CHECK_LE(*input_idx, inputs.size()); + Node* n = inputs[*input_idx].first; + int slot = inputs[*input_idx].second; + const OpDef::ArgDef& arg = n->op_def().output_arg(slot); + // If input node 'n' is producing a list/array output at output + // slot 'slot' then we need to find out the length of that list/array. + if (ArgIsList(arg)) { + int N = GetTensorListLength(arg, n); + CHECK_LE(N, list_length); + for (int j = 0; j < N; j++) { + output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); + } + (*input_idx)++; + list_length -= N; + } else { + // But if input node 'n' is just producing a single tensor at + // output slot 'slot' then we just add that single node. + output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); + (*input_idx)++; + list_length--; + } + } +} + // TODO(nhasabni) We should move this to mkl_util.h. 
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, - Node** out, Node* orign) { + Node** out, Node* orig_node) { // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent // dummy Mkl tensor. 8 = 2*size_t. const DataType dt = DataTypeToEnum<uint8>::v(); @@ -574,63 +774,228 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, 8); TensorShape dummy_shape({8}); dummy_shape.AsProto(proto.mutable_tensor_shape()); - TF_CHECK_OK( - NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orign->def().device()) // We place this node on same - // device as device of original - // node. - .Finalize(&**g, out)); - (*out)->set_assigned_device_name(orign->assigned_device_name()); + TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") + .Attr("value", proto) + .Attr("dtype", dt) + .Device(orig_node->def().device()) // We place this node on + // the same device as the + // device of the original + // node. + .Finalize(&**g, out)); + (*out)->set_assigned_device_name(orig_node->assigned_device_name()); } -Status MklLayoutRewritePass::SetUpInputs( +void MklLayoutRewritePass::GetNodesProducingMklTensorList( + std::unique_ptr<Graph>* g, + const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, + int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { + CHECK_LT(*input_idx, inputs.size()); + CHECK_GT(list_length, 0); + CHECK_NOTNULL(output_nodes); + output_nodes->reserve(list_length); + + while (list_length != 0) { + CHECK_GT(list_length, 0); + CHECK_LE(*input_idx, inputs.size()); + Node* n = inputs[*input_idx].first; + int slot = inputs[*input_idx].second; + const OpDef::ArgDef& arg = n->op_def().output_arg(slot); + // We need to check first if the input edge is going to carry a + // single tensor or a list of tensors. If it is a list of tensors, + // then we need to create list of Mkl dummy nodes. 
+ if (ArgIsList(arg)) { + // If input node 'n' is producing a list/array output at output + // slot 'slot' then we need to find out the length of that list/array. + int N = GetTensorListLength(arg, n); + CHECK_LE(N, list_length); + Node* mkl_node = nullptr; + int mkl_node_output_slot = 0; + // If it is a list, then create a list of Mkl dummy nodes. + for (int j = 0; j < N; j++) { + GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); + } + (*input_idx)++; + list_length -= N; + } else { + // If it is not a list, then create a single Mkl tensor node. + Node* mkl_node = nullptr; + int mkl_node_output_slot = 0; + GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot); + output_nodes->push_back( + NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); + (*input_idx)++; + list_length--; + } + } +} + +// Get an input node that will feed Mkl tensor to the new +// node that we are constructing. An input node could be (1) 'n' +// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor +// if 'n' is not an Mkl layer. +void MklLayoutRewritePass::GetNodeProducingMklTensor( + std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node, + int* mkl_node_output_slot) { + CHECK_NOTNULL(n); + CHECK_NOTNULL(mkl_node); + CHECK_NOTNULL(mkl_node_output_slot); + if (IsRewrittenNode(n)) { + // If we have visited this node and rewritten it, then it will generate + // an edge that will receive Mkl tensor from a node. + // First, let's assert that this op is Mkl layer. + DataType T; + TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T)); + // If this op has been rewritten, then its name must have been same as + // Mkl op. + CHECK_EQ(mkl_op_registry::IsMklOp(n->type_string(), T), true); + // output slot number for Mkl tensor would be N+slot number of TensorFlow + // tensor, where N is total number of TensorFlow tensors. 
+ *mkl_node = n; + *mkl_node_output_slot = + GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); + } else { + // If we have not visited the node and rewritten it, then we need + // to create a dummy node that will feed a dummy Mkl tensor to this node. + // DummyMklTensor node has no input and generates only 1 output + // (dummy Mkl tensor) as output slot number 0. + GetDummyMklTensorNode(g, mkl_node, n); + CHECK_NOTNULL(*mkl_node); + *mkl_node_output_slot = 0; + } +} + +int MklLayoutRewritePass::SetUpContiguousInputs( std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, NodeBuilder* nb, - Node* orign) { - std::vector<NodeBuilder::NodeOut> new_inputs; - - // 1. Let's setup inputs for the new node. - for (int i = 0; i < inputs.size(); i++) { - Node* n = inputs[i].first; - // First let's copy original TF tensor input as it is. - new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second)); - - // Second, let's add edge to propagate Mkl tensors from input Mkl layers, - // or generate a dummy Mkl tensor representing not-mkl-tensor case. - if (IsRewrittenNode(n)) { - // If we have visited this node and rewritten it, then it will generate - // an edge that will receive Mkl tensor from a node. - // First, let's assert that this op is Mkl layer. - DataType T; - TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T)); - // If this op has been rewritten, then its name must have been same as - // Mkl op. - CHECK_EQ(mkl_layer_registry::IsMklLayer(n->type_string(), T), true); - // src slot number for Mkl tensor would be the one next to TF tensor - // slot number. 
- new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second + 1)); + const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, + NodeBuilder* nb, Node* old_node, + std::vector<NodeBuilder::NodeOut>* workspace_tensors, + bool are_workspace_tensors_available) { + CHECK_NOTNULL(workspace_tensors); + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + + // Number of input slots to original op + // Input slots are represented by .Input() calls in REGISTER_OP. + int old_node_input_slots = old_node->op_def().input_arg_size(); + // Actual number of inputs can be greater than or equal to number + // of Input slots because inputs of type list could be unfolded. + CHECK_GE(old_node_inputs.size(), old_node_input_slots); + int nn_slot_idx = 0; // slot index for inputs of new node + + // Let's copy all inputs (TF tensors) of original node to new node. + int iidx = 0; + for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { + // An input slot could be a single tensor or a list. We need + // to handle this case accordingly. + CHECK_LT(iidx, old_node_inputs.size()); + const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); + if (ArgIsList(arg)) { + std::vector<NodeBuilder::NodeOut> new_node_inputs; + int N = GetTensorListLength(arg, old_node); + GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, + &new_node_inputs); + nb->Input(new_node_inputs); + nn_slot_idx++; } else { - // If we have not visited the node and rewritten it, then we need - // to create a dummy node that will feed a non-Mkl tensor to this node. - // DummyMklTensor node has no input and generates only 1 output - // (dummy Mkl tensor) as output slot number 0. 
- Node* dmt = nullptr; - GetDummyMklTensorNode(g, &dmt, orign); - CHECK_NOTNULL(dmt); - new_inputs.push_back(NodeBuilder::NodeOut(dmt, 0)); + nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); + iidx++; + nn_slot_idx++; } } - // The total number of inputs to new node _must_ be 2 times the number - // of inputs to the original node: N original Tensorflow tensors and - // N for Mkl tensors corresponding to each Tensorflow tensors. - CHECK_EQ(new_inputs.size(), inputs.size() * 2); + // If workspace tensors are available for this op and we are using + // contiguous ordering then we need to add Tensorflow tensor for + // workspace here because Tensorflow tensor for workspace is the + // last tensor in the list of Tensorflow tensors. + if (are_workspace_tensors_available) { + CHECK_EQ(workspace_tensors->size(), 2); + // Tensorflow tensor + nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); + nn_slot_idx++; + } + + // Let's now setup all Mkl inputs to new node. + // Number of Mkl inputs must be same as number of TF inputs. + iidx = 0; + for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { + // An input slot could be a single tensor or a list. We need + // to handle this case accordingly. 
+ CHECK_LT(iidx, old_node_inputs.size());
+ const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
+ if (ArgIsList(arg)) {
+ std::vector<NodeBuilder::NodeOut> new_node_inputs;
+ int N = GetTensorListLength(arg, old_node);
+ GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
+ &new_node_inputs);
+ nb->Input(new_node_inputs);
+ nn_slot_idx++;
+ } else {
+ Node* mkl_node = nullptr;
+ int mkl_node_output_slot = 0;
+ GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
+ old_node_inputs[iidx].second, &mkl_node,
+ &mkl_node_output_slot);
+ nb->Input(mkl_node, mkl_node_output_slot);
+ iidx++;
+ nn_slot_idx++;
+ }
+ }
+
+ // If workspace tensors are available for this op and we are using
+ // contiguous ordering then we need to add Mkl tensor for
+ // workspace here because Mkl tensor for workspace is the
+ // last tensor in the list of Mkl tensors.
+ if (are_workspace_tensors_available) {
+ CHECK_EQ(workspace_tensors->size(), 2);
+ // Mkl tensor
+ nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
+ nn_slot_idx++;
+ }
+
+ return nn_slot_idx;
+}
+
+Status MklLayoutRewritePass::SetUpInputs(
+ std::unique_ptr<Graph>* g,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+ NodeBuilder* nb, Node* old_node) {
+ // Let's check if we need to add workspace tensors for this node.
+ // We add workspace edge only for MaxPool, LRN and BatchNorm.
+ std::vector<NodeBuilder::NodeOut> workspace_tensors;
+ bool are_workspace_tensors_available = false;
+ AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
+ &are_workspace_tensors_available);
+
+ int new_node_input_slots = 0;
+ if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+ // TODO(nhasabni): implement this function just for the sake of completion.
+ // We do not use interleaved ordering right now.
+ return Status(
+ error::Code::UNIMPLEMENTED,
+ "Interleaved ordering of tensors is currently not supported.");
+ } else {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ new_node_input_slots = SetUpContiguousInputs(
+ g, old_node_inputs, nb, old_node, &workspace_tensors,
+ are_workspace_tensors_available);
+ }
 
-  // 2. Let's add the new inputs.
-  for (auto ni : new_inputs) {
-    nb->Input(ni.node, ni.index);
+ // Sanity check
+ int old_node_input_slots = old_node->op_def().input_arg_size();
+ if (!are_workspace_tensors_available) {
+ // If we are not adding workspace tensors for this op, then the total
+ // number of input slots to the new node _must_ be 2 times the number
+ // of input slots to the original node: N original Tensorflow tensors and
+ // N for Mkl tensors corresponding to each Tensorflow tensors.
+ CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
+ } else {
+ // If we are adding workspace tensors for this op, then the total
+ // number of input slots to the new node _must_ be 2 times the number
+ // of input slots to the original node: N original Tensorflow tensors and
+ // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
+ // (for workspace Tensorflow tensor and workspace Mkl tensor).
+ CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
 }
 
 return Status::OK();
@@ -642,7 +1007,7 @@ Status MklLayoutRewritePass::SetUpInputs(
 
 // TODO(nhasabni) We should move this to mkl_util.h.
 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
-    std::unique_ptr<Graph>* g, Node** out, Node* orign) {
+    std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
 // We use a tensor of shape {1} and value 0 to represent
 // dummy float tensor. We need this as a dummy workspace tensor.
 // Workspace tensor has type float.
@@ -654,39 +1019,42 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
 4);
 TensorShape dummy_shape({1});
 dummy_shape.AsProto(proto.mutable_tensor_shape());
-  TF_CHECK_OK(
-      NodeBuilder((*g)->NewName("DMT"), "Const")
-          .Attr("value", proto)
-          .Attr("dtype", dt)
-          .Device(orign->def().device())  // We place this node on same
-                                          // device as device of original
-                                          // node.
-          .Finalize(&**g, out));
-  (*out)->set_assigned_device_name(orign->assigned_device_name());
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
+ (*out)->set_assigned_device_name(orig_node->assigned_device_name());
 }
 
-void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
-                                                    Node* orign,
-                                                    NodeBuilder* nb) {
-  bool workspace_edge_added = false;
+void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
+ std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
+ std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
+ bool workspace_edge_added = false; // Default initializer
+ CHECK_NOTNULL(are_ws_tensors_added);
+ *are_ws_tensors_added = false; // Default initializer
+
 DataType T;
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
 for (auto ws : wsinfo_) {
-    if (orign->type_string() == ws.fwdop &&
-        mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()), T)) {
+ if (orig_node->type_string() == ws.fwd_op &&
+ mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) {
 // If this op is a fwd op, then we need to check if there is an
-      // edge from this node's fwdslot to bwdop's bwdslot. If there is
+ // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
 // an edge, then we just add an attribute on this node for setting
 // workspace_passed to true.
We don't add actual workspace edge // in this node. Actual workspace edge gets added in the backward // op for this node. - for (const Edge* e : orign->out_edges()) { - if (e->src_output() == ws.fwdslot && - e->dst()->type_string() == ws.bwdop && - e->dst_input() == ws.bwdslot) { + for (const Edge* e : orig_node->out_edges()) { + if (e->src_output() == ws.fwd_slot && + e->dst()->type_string() == ws.bwd_op && + e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " - << orign->type_string(); + << orig_node->type_string(); workspace_edge_added = true; // We found the edge that we were looking for, so break. break; @@ -698,34 +1066,40 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, // node. nb->Attr("workspace_enabled", false); } - } else if (orign->type_string() == ws.bwdop && - mkl_layer_registry::IsMklLayer( - GetMklOpName(orign->type_string()), T)) { + } else if (orig_node->type_string() == ws.bwd_op && + mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), + T)) { // If this op is a bwd op, then we need to add workspace edge and // it's Mkl tensor edge between its corresponding fwd op and this - // op. Corresponding fwd op is specified in 'fwdop' field of - // workspace info. fwdslot and bwdslot in workspace info specify + // op. Corresponding fwd op is specified in 'fwd_op' field of + // workspace info. fwd_slot and bwd_slot in workspace info specify // an edge between which slots connect forward and backward op. // Once all these criteria match, we add a workspace edge between - // wsfwdslot and wsbwdslot. It's corresponding Mkl tensor is added - // in wsfwdslot+1 and wsbwdslot+1. - for (const Edge* e : orign->in_edges()) { - if (e->src_output() == ws.fwdslot && + // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is + // determined by interleaved/contiguous ordering. 
Function + // DataIndexToMetaDataIndex tells us the location of Mkl tensor + // from the location of the Tensorflow tensor. + for (const Edge* e : orig_node->in_edges()) { + if (e->src_output() == ws.fwd_slot && // We would have rewritten the forward op, so we need to use // GetMklOpName call to get its Mkl name. - e->src()->type_string() == GetMklOpName(ws.fwdop) && - e->dst_input() == ws.bwdslot) { + e->src()->type_string() == GetMklOpName(ws.fwd_op) && + e->dst_input() == ws.bwd_slot) { nb->Attr("workspace_enabled", true); + CHECK_NOTNULL(ws_tensors); // Add workspace edge between fwd op and bwd op. - nb->Input(e->src(), ws.wsfwdslot); + ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); // Add Mkl tensor edge for workspace edge between fwd op and bwd op. - nb->Input(e->src(), ws.wsfwdslot + 1); + ws_tensors->push_back(NodeBuilder::NodeOut( + e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, + e->src()->num_outputs()))); + *are_ws_tensors_added = true; // In terms of input ordering, we add these calls to add Input // here because workspace edge (and its Mkl tensor) is the last // edge in the fwdop and bwdop. So all inputs before workspace // tensor have been added by SetUpInputs function. VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " - << orign->type_string(); + << orig_node->type_string(); workspace_edge_added = true; // We found the edge that we were looking for, so break. 
break; @@ -740,15 +1114,18 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, nb->Attr("workspace_enabled", false); Node* dmt_ws = nullptr; // Dummy tensor for workspace Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace - GetDummyWorkspaceTensorNode(g, &dmt_ws, orign); - GetDummyMklTensorNode(g, &dmt_mkl_ws, orign); + GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); + GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); CHECK_NOTNULL(dmt_ws); CHECK_NOTNULL(dmt_mkl_ws); - nb->Input(dmt_ws, 0); // We add dummy tensor as workspace tensor. - nb->Input(dmt_mkl_ws, 0); // We add dummy tensor as Mkl - // tensor for workspace tensor. + CHECK_NOTNULL(ws_tensors); + // We add dummy tensor as workspace tensor. + ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); + // We add dummy tensor as Mkl tensor for workspace tensor. + ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); + *are_ws_tensors_added = true; VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " - << orign->type_string(); + << orig_node->type_string(); } } else { // If this node does not match any workspace info, then we do not @@ -761,7 +1138,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, // Op-specific functions to copy attributes from old node to new node ////////////////////////////////////////////////////////////////////////// -void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, + NodeBuilder* nb) { DataType T; string data_format; string padding; @@ -769,11 +1147,12 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) { bool use_cudnn_on_gpu; // Get all attributes from old node. 
- TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); + TF_CHECK_OK( + GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); // Add attributes to new node. nb->Attr("T", T); @@ -783,16 +1162,16 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) { nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); } -void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign, +void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb) { DataType T; string data_format; std::vector<int32> strides; // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); // Add attributes to new node. nb->Attr("T", T); @@ -800,7 +1179,30 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign, nb->Attr("data_format", data_format); } -void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign, +void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + int depth_radius; + float bias; + float alpha; + float beta; + + // Get all attributes from old node. 
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("depth_radius", depth_radius); + nb->Attr("bias", bias); + nb->Attr("alpha", alpha); + nb->Attr("beta", beta); +} + +void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb) { DataType T; string data_format; @@ -808,11 +1210,11 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign, std::vector<int32> ksize, strides; // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "ksize", &ksize)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); // Add attributes to new node. nb->Attr("T", T); @@ -822,14 +1224,97 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign, nb->Attr("data_format", data_format); } -void MklLayoutRewritePass::CopyAttrsRelu(const Node* orign, NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + + // Add attributes to new node. 
+ nb->Attr("T", T); +} + +void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + string data_format; + int num_split; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("num_split", num_split); + nb->Attr("data_format", data_format); +} + +void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + int N; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("N", N); +} + +void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + int N; + DataType tidx; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("N", N); + nb->Attr("Tidx", tidx); +} + +void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + float epsilon; + string data_format; + bool is_training; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); + + // Add attributes to new node. 
+ nb->Attr("T", T); + nb->Attr("epsilon", epsilon); + nb->Attr("data_format", data_format); + nb->Attr("is_training", is_training); +} + +void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, + NodeBuilder* nb) { DataType T; + DataType Tshape; // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); // Add attributes to new node. nb->Attr("T", T); + nb->Attr("Tshape", Tshape); } ////////////////////////////////////////////////////////////////////////// @@ -889,8 +1374,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, CHECK_NOTNULL(succ); CHECK_NOTNULL(pred); - if (succ->type_string() == csinfo_.biasadd && - pred->type_string() == csinfo_.mklconv2d) { + if (succ->type_string() == csinfo_.bias_add && + pred->type_string() == csinfo_.mkl_conv2d) { // 1. Get all attributes from input nodes. DataType T_pred, T_succ; string padding; @@ -947,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, // 2. Get inputs from both the nodes. // Find the 2 inputs from the conv and the bias from the add Bias. // Get operand 0, 1 of conv2D and their Mkl tensors. - CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs. + CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs. // Get operand 1 of add_bias // BiasAdd must have 2 inputs: Conv, bias CHECK_EQ(succ->in_edges().size(), 2); @@ -960,13 +1445,29 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, // We will use the node name of BiasAdd as the name of new node // Build new node. We use same name as original node, but change the op // name. 
- NodeBuilder nb(succ->name(), csinfo_.mklconv2dwithbias); - nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D - nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1 - nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D - nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 - nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd - nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd + NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias); + if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { + nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D + // pred_in[1] will be Mkl tensor for In1 if we follow interleaved + // ordering, and it will be 2nd Tensorflow tensor for Conv2D if + // we follow contiguous ordering. + nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1 + nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D + nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 + nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd + nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd + } else { + CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); + nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D + // pred_in[1] will be Mkl tensor for In1 if we follow interleaved + // ordering, and it will be 2nd Tensorflow tensor for Conv2D if + // we follow contiguous ordering. + nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D + nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd + nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D + nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D + nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd + } // Copy attributes from Conv2D to Conv2DWithBias. 
CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); @@ -975,30 +1476,30 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, nb.Device(succ->def().device()); // Create node. - Node* newn; - nb.Finalize(&**g, &newn); - CHECK_NOTNULL(newn); + Node* new_node; + nb.Finalize(&**g, &new_node); + CHECK_NOTNULL(new_node); // Set the Mkl layer label for this op. - newn->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel); + new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); // Incoming edges are fixed, we will fix the outgoing edges now. for (const Edge* e : succ->out_edges()) { - (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input()); + (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); } // Copy device assigned to old node to new node. // It's ok to use pred or succ as we have enforced a check that // both have same device assigned. - newn->set_assigned_device_name(pred->assigned_device_name()); + new_node->set_assigned_device_name(pred->assigned_device_name()); VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() << ", and node: " << succ->DebugString() - << ", into node:" << newn->DebugString(); + << ", into node:" << new_node->DebugString(); (*g)->RemoveNode(succ); (*g)->RemoveNode(pred); - MarkRewrittenNode(newn); + MarkRewrittenNode(new_node); return Status::OK(); } @@ -1011,35 +1512,39 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, // Helper functions for node rewrite ////////////////////////////////////////////////////////////////////////// -Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign, +Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, + Node* orig_node, const RewriteInfo* ri) { CHECK_NOTNULL(ri); - CHECK_NOTNULL(orign); + CHECK_NOTNULL(orig_node); - VLOG(1) << "MklLayoutRewritePass: Original node:" << orign->DebugString(); + VLOG(1) << "MklLayoutRewritePass: Original node:" << 
orig_node->DebugString(); // Check if this is scenario 2 (context-based rewrite). // Get the matching ContextInfo if it is. - const Node* fwdn = nullptr; + const Node* fwd_node = nullptr; const ContextInfo* ci = nullptr; bool is_context_based_rewrite = false; - if ((ci = SearchMatchingContext(orign, &fwdn)) != nullptr) { - CHECK_NOTNULL(fwdn); + if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) { + CHECK_NOTNULL(fwd_node); is_context_based_rewrite = true; // Sanity checks for context-based rewrite (if any) - if (orign->type_string() == csinfo_.biasaddgrad && - ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) { + if (orig_node->type_string() == csinfo_.bias_add_grad && + ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { DataType orig_T, ctx_T; string orig_data_format, ctx_data_format; - TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &orig_T)); - TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &orig_data_format)); - TF_CHECK_OK(GetNodeAttr(fwdn->def(), "T", &ctx_T)); - TF_CHECK_OK(GetNodeAttr(fwdn->def(), "data_format", &ctx_data_format)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); + TF_CHECK_OK( + GetNodeAttr(orig_node->def(), "data_format", &orig_data_format)); + TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T)); + TF_CHECK_OK( + GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format)); if (orig_data_format != ctx_data_format || orig_T != ctx_T || - orign->assigned_device_name() != fwdn->assigned_device_name() || - orign->def().device() != fwdn->def().device()) { + orig_node->assigned_device_name() != + fwd_node->assigned_device_name() || + orig_node->def().device() != fwd_node->def().device()) { return Status( error::Code::INVALID_ARGUMENT, "data_format or T attribute or devices of BiasAddGrad and " @@ -1049,18 +1554,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign, } // Get all inputs. 
- const int num = orign->num_inputs(); - CHECK_EQ(num, ri->numins); + const int num = orig_node->in_edges().size(); + // Check the number of inputs against the user-specified value for non-vararg + // nodes. + if (!IsVarArgNode(orig_node)) { + CHECK_EQ(num, ri->num_ins); + } gtl::InlinedVector<Node*, 4> control_edges; gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num); - FillInputs(orign, &control_edges, &inputs); + FillInputs(orig_node, &control_edges, &inputs); // Build new node. We use same name as original node, but change the op name. - NodeBuilder nb(orign->name().c_str(), ri->newname.c_str()); + NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); // Copy user-specified device assigned to original node to new node. - nb.Device(orign->def().device()); + nb.Device(orig_node->def().device()); // Set up new inputs to the rewritten node. - Status s = SetUpInputs(g, inputs, &nb, orign); + Status s = SetUpInputs(g, inputs, &nb, orig_node); if (s != Status::OK()) { return s; } @@ -1068,62 +1577,63 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign, // Copy attributes from original node to new node (for scenario 1). // For context-based rewrite, we use context to copy the attributes. if (is_context_based_rewrite) { - if (orign->type_string() == csinfo_.biasaddgrad && - ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) { - CHECK_NOTNULL(fwdn); - ri->copyattrs(fwdn, &nb); + if (orig_node->type_string() == csinfo_.bias_add_grad && + ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { + CHECK_NOTNULL(fwd_node); + ri->copy_attrs(fwd_node, &nb); } else { return Status(error::Code::UNIMPLEMENTED, "Unimplemented case for node rewrite optimization."); } } else { - ri->copyattrs(const_cast<const Node*>(orign), &nb); + ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); } // Set the Mkl layer label for this op. 
- nb.Attr("_kernel", mkl_layer_registry::kMklLayerLabel); - - // Add workspace edge to this node if needed. - // We add workspace edge only for MaxPool, LRN and BatchNorm. - AddWorkSpaceEdgeIfNeeded(g, orign, &nb); + nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); // Finalize graph and get new node. - Node* newn = nullptr; - TF_CHECK_OK(nb.Finalize(&**g, &newn)); - CHECK_NOTNULL(newn); - - // Incoming edges from 'orign' node to new 'newn' node are already copied - // in BuildNode. Copy outgoing edges from 'orign' node to new 'newn' node. - // Since the output also follows same ordering among Tensorflow tensors and - // Mkl tensors. We need to connect Tensorflow tensors appropriately. - // Specifically, nth output of original node will become 2*nth output of - // Mkl node. GetTensorDataIndex provides this mapping function. - for (const Edge* e : orign->out_edges()) { + Node* new_node = nullptr; + TF_CHECK_OK(nb.Finalize(&**g, &new_node)); + CHECK_NOTNULL(new_node); + + // Incoming edges from 'orig_node' node to new 'new_node' node are already + // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new + // 'new_node' node, since the output also follows same ordering among + // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow + // tensors appropriately. Specifically, nth output of the original node + // will become 2*nth output of the Mkl node for the interleaved ordering + // of the tensors. For the contiguous ordering of the tensors, it will be n. + // GetTensorDataIndex provides this mapping function. + for (const Edge* e : orig_node->out_edges()) { // We need to handle control-edges by using their original slot number. // Generally, -1 is reserved for control slot. 
if (e->src_output() < 0) { - (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input()); + (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()); } else { - (*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()), e->dst(), - e->dst_input()); + (*g)->AddEdge( + new_node, + GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), + e->dst(), e->dst_input()); } } // Copy the runtime device assigned from original code to new node. - newn->set_assigned_device_name(orign->assigned_device_name()); + new_node->set_assigned_device_name(orig_node->assigned_device_name()); // Delete original node and mark new node as rewritten. - (*g)->RemoveNode(orign); - MarkRewrittenNode(newn); + (*g)->RemoveNode(orig_node); + MarkRewrittenNode(new_node); - VLOG(1) << "MklLayoutRewritePass: New node:" << newn->DebugString(); + VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); return Status::OK(); } const MklLayoutRewritePass::ContextInfo* -MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) { +MklLayoutRewritePass::SearchMatchingContext(const Node* n, + const Node** fwd_node) { CHECK_NOTNULL(n); - CHECK_NOTNULL(fwdn); - *fwdn = nullptr; + CHECK_NOTNULL(fwd_node); + *fwd_node = nullptr; // Search for matching contextinfo based on node name. // There could be more than one matching contextinfos. @@ -1171,7 +1681,7 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) { // If we find a match, we return immediately. 
for (const ContextInfo* ci : mci) { if (curr_node->type_string() == ci->fwd) { - *fwdn = curr_node; + *fwd_node = curr_node; return ci; } } @@ -1192,8 +1702,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) { } bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) { - const Node* fwdn = nullptr; - return SearchMatchingContext(n, &fwdn) != nullptr; + const Node* fwd_node = nullptr; + return SearchMatchingContext(n, &fwd_node) != nullptr; } const MklLayoutRewritePass::RewriteInfo* @@ -1208,7 +1718,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { if (!GetNodeAttr(n->def(), "T", &T).ok()) { return nullptr; } - if (!mkl_layer_registry::IsMklLayer(GetMklOpName(n->type_string()), T)) { + + if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) { return nullptr; } @@ -1219,7 +1730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { // Find matching RewriteInfo and then check that rewrite rule applies. for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && ri->rewriterule(n)) { + if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { return &*ri; } } |