Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 1086
1 file changed, 922 insertions, 164 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 9a2e4bcfa0..309c4cd774 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -15,13 +15,15 @@ limitations under the License.
 #ifdef INTEL_MKL
+#include <algorithm>
 #include <functional>
 #include <memory>
+#include <queue>
+#include <set>
 #include <string>
 #include <unordered_set>
 #include <utility>
 #include <vector>
-
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
 #include "tensorflow/core/framework/node_def_util.h"
@@ -39,68 +41,91 @@ limitations under the License.
 namespace tensorflow {
 
-// This pass implements rewriting of graph for propagating Mkl
-// layout as an additional output tensor (we will loosely call a
-// tensor that carries Mkl layout as Mkl tensor henceforth.)
-// from every Mkl supported NN layer.
+// This pass implements rewriting of the graph to support the following
+// scenarios:
+// (A) Merging nodes in the graph
+// (B) Rewriting a node in the graph to a new node
+// Rewrite happens under the following 2 scenarios:
+// 1) Propagating Mkl layout as an additional output tensor
+//    (we will loosely call a tensor that carries Mkl layout as Mkl tensor
+//    henceforth.) from every Mkl supported NN layer.
+// 2) Context-based rewrite: This is needed in order to optimize
+//    gradient ops of Conv2D+BiasAdd. Gradient op of both the Conv2D and
+//    MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
+//    Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
+//    This is a context-specific optimization, where the context is the
+//    forward operator that the BiasAddGrad corresponds to.
+//
+// Example of A: Merging nodes in the graph
+// -----------------------------------------
+// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
+//
+//        O = Conv2D(A, B)
+//        P = BiasAdd(O, C)
+//
+// We merge them into Conv2DWithBias as:
+//        P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
 //
-// As a example, consider Relu layer. Current definition of Relu
-// layer looks like:
+// The meaning of A_m, B_m and C_m is explained in B.1.
+//
+// Merge rules:
+//  - Merging of Conv2D and BiasAdd happens only when the output of Conv2D
+//    _only_ goes to BiasAdd.
+//  - Also, the intersection of attributes of both the nodes must have the
+//    same values.
+//  - Both the nodes must have been assigned to the same device (if any).
+//
+// Example of B.1: Rewriting nodes to Mkl nodes
+// ---------------------------------------------
+// Consider Relu layer. Current definition of Relu layer looks like:
 //
 //        O = Relu(A)
 //
 // Relu has 1 input (A), and 1 output (O).
 //
-// This rewrite pass will generate a new graph node for Relu
-// (new node is called MklRelu) as:
+// This rewrite pass will generate a new graph node for Relu (new node is
+// called MklRelu) as:
 //
 //        O, O_m = MklRelu(A, A_m)
 //
-// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m).
-// Here A input is same as A input of Relu; O output is same
-// as O output of Relu. O_m is the additional output tensor
-// that will be set by MklRelu, and it represents Mkl tensor
-// corresponding to O -- in other words, O_m is some kind of
-// metadata for O. A_m is additional input of Relu, and it
-// represents metadata for A - as O_m is metadata for O, A_m
-// is metadata for A. MklRelu receives this metadata from
-// previous layer (in the graph).
+// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here A input is
+// same as A input of Relu; O output is same as O output of Relu. O_m is the
+// additional output tensor that will be set by MklRelu, and it represents
+// Mkl tensor corresponding to O -- in other words, O_m is some kind of
+// metadata for O. A_m is an additional input of Relu, and it represents
+// metadata for A - as O_m is metadata for O, A_m is metadata for A. MklRelu
+// receives this metadata from the previous layer (in the graph).
 //
-// When previous layer in the graph is Mkl layer, A_m will
-// represent a valid Mkl tensor. But when previous Mkl layer
-// is not an Mkl layer, then A_m represents a dummy Mkl tensor.
+// When the previous layer in the graph is an Mkl layer, A_m will represent a
+// valid Mkl tensor. But when the previous layer is not an Mkl layer, then A_m
+// represents a dummy Mkl tensor.
 //
 // Rewriting rules:
-//  - Selection of an op for rewriting happens by registering
-//    an op with this pass. If an op is not registered, then
-//    it is not rewritten.
+//  - Selection of an op for rewriting happens by registering an op with this
+//    pass. If an op is not registered, then it is not rewritten.
 //  - Number of inputs after rewriting:
-//      Since for every input Tensorflow tensor, the rewritten
-//      layer gets Mkl tensor, rewritten op gets 2*N inputs,
-//      where N is the number of inputs for original op.
+//      Since for every input Tensorflow tensor, the rewritten layer gets an
+//      Mkl tensor, the rewritten op gets 2*N inputs, where N is the number of
+//      inputs for the original op.
 //  - Number of outputs after rewriting:
-//      Since for every output Tensorflow tensor, the rewritten
-//      layer generates Mkl tensor, rewritten op generates 2*N
-//      outputs, where N is the number of outputs of original op.
+//      Since for every output Tensorflow tensor, the rewritten layer generates
+//      an Mkl tensor, the rewritten op generates 2*N outputs, where N is the
+//      number of outputs of the original op.
 //  - Ordering of Tensorflow tensors and Mkl tensors:
-//      Since every op generates twice the number of inputs and
-//      outputs, one could imagine different ordering among
-//      Tensorflow tensors and Mkl tensors. E.g., let's assume
-//      an op 'Conv2D' takes (A, B) as input, then new op
-//      'MklConv2D' can take (A, A_m, B, B_m) as input or it
-//      can also take (A, B, A_m, B_m) as input. Among N inputs
-//      one can get N! permutations.
-//
-//      So the question is: which one do we follow? Currently,
-//      we follow an intuitive order where Mkl tensor follows a
-//      corresponding Tensorflow tensor immediately. In the
-//      context of above example, it will be: (A, A_m, B, B_m).
-//      We follow same ordering rule for output tensors.
-//
-// NOTE: Current rewriting approach rewrites an op to Mkl op without
-//       any conditions. But in the future, it may be possible to
-//       consider conditions such as input shapes and sizes to rewrite
-//       an op.
+//      Since every op generates twice the number of inputs and outputs, one
+//      could imagine different orderings among Tensorflow tensors and Mkl
+//      tensors. E.g., let's assume an op 'Conv2D' takes (A, B) as input; then
+//      the new op 'MklConv2D' can take (A, A_m, B, B_m) as input, or it can
+//      also take (A, B, A_m, B_m) as input. Among N inputs one can get N!
+//      permutations.
+//
+//      So the question is: which one do we follow? Currently, we follow an
+//      intuitive order where an Mkl tensor follows its corresponding
+//      Tensorflow tensor immediately. In the context of the above example, it
+//      will be: (A, A_m, B, B_m). We follow the same ordering rule for output
+//      tensors.
+//
+// NOTE: Current rewriting approach rewrites an op to Mkl op without any
+//       conditions. But in the future, it may be possible to consider
+//       conditions such as input shapes and sizes to rewrite an op.
 //
 // Graph rewrite algorithm:
 //      Algorithm: Graph Rewrite
@@ -147,13 +172,137 @@ namespace tensorflow {
 //      it is, then we rewrite that node after constructing new inputs to
 //      the node. If it is not Mkl layer, then we do not rewrite the node.
 //
+// Handling workspace propagation for certain ops:
+//
+// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
+// passing of a workspace from their corresponding forward ops. But
+// TensorFlow does not have a notion of workspace and as a result
+// does not allow producing additional outputs from these forward ops.
+// For these ops, we need to add an additional edge between forward
+// ops and their corresponding backward ops; this edge carries the
+// workspace tensor value and another edge carries the Mkl tensor for
+// the workspace tensor.
+//
+// Example:
+//
+// Typical graph for MaxPool and its gradient looks like:
+//
+// A = MaxPool(T)
+// B = MaxPoolGrad(X, A, Y)
+//
+// We will transform this graph to propagate the workspace as:
+//
+// A, A_m, W, W_m = MklMaxPool(T, T_m)
+// B, B_m = MklMaxPoolGrad(X, X_m, A, A_m, Y, Y_m, W, W_m)
+//
+// Here W is the workspace tensor. Transformed tensors with name
+// suffix _m are Mkl tensors, and this transformation has been done
+// using the algorithm discussed earlier. The transformation for
+// workspace only adds extra outputs (W, W_m) for the forward op and
+// connects them to the corresponding backward ops.
+//
+// Terms:
+//
+// Forward op name = name of the op in the forward pass
+//   where the workspace originates (MaxPool in this example)
+// Backward op name = name of the op in the backward pass that receives
+//   the workspace from the forward op (MaxPoolGrad in the example)
+// Slot = Number of the output or input slot that will be
+//   used by the workspace (2 for MklMaxPool as W is the 3rd
+//   output of MaxPool (0 is the 1st); 6 for MklMaxPoolGrad)
+//
+// Question:
+//
+// How do we associate a backward op to a forward op? There can be more
+// than one op with the exact same name.
+//
+// In this example we associate MaxPoolGrad with MaxPool. But there
+// could be more than one MaxPool op. To solve this problem, we look
+// for a _direct_ edge between the forward op and the backward op (tensor A
+// is flowing along this edge in the example.)
+//
+// How do we transform the forward and backward ops when there is no direct
+// edge between them? In such a case, we generate dummy tensors as
+// workspace tensors. For the example, the transformation of MaxPool will
+// be exactly the same --- it is just that MaxPool won't generate any
+// workspace tensor. For MaxPoolGrad, the transformation will also be the
+// same, but instead of connecting W and W_m with the outputs of MaxPool, we
+// will produce dummy tensors for them, and we will set the
+// workspace_enabled attribute to false.
+//
+// Example of B.2: Context-based node rewrite
+// -------------------------------------------
+// Consider BiasAddGrad op as:
+//
+//        O = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
+//        P = BiasAddGrad(O)
+//
+// Then we rewrite it as:
+//
+//        P = Conv2DWithBiasBackpropBias(O, O_m)
+//
+// 'Distance' between the input of BiasAddGrad and MklConv2DWithBias in terms
+// of hops is the context matching depth. If MklConv2DWithBias is not within
+// the context matching depth, then we do not rewrite BiasAddGrad.
+
+// How many hops do we search for a matching node in the backward dataflow graph?
+// We use a maxhop of 10 based on empirical observations. Also, these are
+// maxhops in the backward data-flow graph. Since the input of forward nodes
+// (Conv2D) directly goes to backward nodes, we do not expect the hop-distance
+// to be more than a few nodes.
+static size_t kNodeMergeContextMaxDepth = 10;
+
 class MklLayoutRewritePass : public GraphOptimizationPass {
  public:
   MklLayoutRewritePass() {
     csinfo_.conv2d = "Conv2D";
-
-    ninfo_.push_back(
-        {csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2, CopyAttrsConv2D});
+    csinfo_.mklconv2d = "MklConv2D";
+    csinfo_.mklconv2dwithbias = "MklConv2DWithBias";
+    csinfo_.mklconv2dwithbiasbackpropbias = "MklConv2DWithBiasBackpropBias";
+    csinfo_.biasadd = "BiasAdd";
+    csinfo_.matmul = "MatMul";
+    csinfo_.biasaddgrad = "BiasAddGrad";
+    csinfo_.relu = "Relu";
+    csinfo_.relugrad = "ReluGrad";
+    csinfo_.maxpool = "MaxPool";
+    csinfo_.maxpoolgrad = "MaxPoolGrad";
+    csinfo_.avgpool = "AvgPool";
+    csinfo_.avgpoolgrad = "AvgPoolGrad";
+    csinfo_.conv2dgradinput = "Conv2DBackpropInput";
+    csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";
+
+    rinfo_.push_back(
+        {csinfo_.conv2d, csinfo_.mklconv2d, 2, CopyAttrsConv2D, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.conv2dgradfilter,
+                      GetMklOpName(csinfo_.conv2dgradfilter), 3,
+                      CopyAttrsConv2D, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.conv2dgradinput,
+                      GetMklOpName(csinfo_.conv2dgradinput), 3, CopyAttrsConv2D,
+                      AlwaysRewrite});
+    rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
+                      CopyAttrsRelu, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool), 1,
+                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad), 3,
+                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool), 1,
+                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad), 2,
+                      CopyAttrsPooling, AlwaysRewrite});
+
+    // Add info about which ops to add workspace edge to and the slots.
+    wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6});
+
+    // Add a rule for merging nodes
+    minfo_.push_back(
+        {csinfo_.mklconv2d, csinfo_.biasadd, 0, csinfo_.mklconv2dwithbias});
+
+    // We use a maxhop of 10 based on empirical observations. Also, these are
+    // maxhops in the backward data-flow graph. Since the input of forward
+    // nodes (Conv2D) directly goes to backward nodes, we do not expect the
+    // hop-distance to be more than a few nodes.
+    cinfo_.push_back({csinfo_.biasaddgrad, csinfo_.mklconv2dwithbias,
+                      kNodeMergeContextMaxDepth});
   }
 
   // Standard interface to run pass
@@ -176,20 +325,79 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string name;     // Original name of the op in the graph
     string newname;  // New name of op in the graph
     int numins;      // Number of inputs to the original op
-    std::function<void(Node*, NodeBuilder*)>
-        copyattrs;  // Function handler
-                    // to copy attributes from old node to new node.
-  } NodesInfo;
+    // Function handler to copy attributes from old node to new node.
+    std::function<void(const Node*, NodeBuilder*)> copyattrs;
+    std::function<bool(const Node*)> rewriterule;  // Rule under which to
+                                                   // rewrite this node.
+  } RewriteInfo;
+
+  /// Structure to specify forward op, backward op, and the slot numbers
+  /// in forward and backward op where we will add workspace edge.
+  typedef struct {
+    string fwdop;   // Name of the forward op in the graph
+    string bwdop;   // Name of the backward op in the graph
+    int fwdslot;    // Output slot in the forward op node where actual
+                    // output tensor resides
+    int bwdslot;    // Input slot in the backward op node where actual
+                    // input tensor resides
+    int wsfwdslot;  // Output slot in the forward op node where workspace
+                    // edge is added
+    int wsbwdslot;  // Input slot in the backward op node where workspace
+                    // edge is added
+  } WorkSpaceInfo;
+
+  /// Structure to specify information used in node merge
+  typedef struct {
+    string pred;     // Predecessor node string
+    string succ;     // Successor node string
+    int op;          // Which operand number of the successor node the
+                     // predecessor corresponds to
+    string newnode;  // Name of the node after merge
+  } MergeInfo;
+
+  /// Structure to specify the context information used in node rewrite rule
+  typedef struct {
+    string node;    // Name of the node to be rewritten
+    string fwd;     // Node name in forward pass that this node
+                    // corresponds to
+    size_t maxhop;  // Maximum number of hops the fwd is located
+                    // from this node. If fwd is farther than maxhop
+                    // then we do not rewrite the node.
+  } ContextInfo;
 
   /// Structure to store all constant strings
   struct {
     string relu;
     string relugrad;
+    // Conv ops
     string conv2d;
+    string mklconv2d;
+    string conv2dgradinput;
+    string conv2dgradfilter;
+    string mklconv2dwithbias;
+    string mklconv2dwithbiasbackpropbias;
+    // Pooling ops
+    string maxpool;
+    string maxpoolgrad;
+    string avgpool;
+    string avgpoolgrad;
+    // Others
+    string biasadd;
+    string matmul;
+    string biasaddgrad;
   } csinfo_;
 
   /// Maintain info about nodes to rewrite
-  std::vector<NodesInfo> ninfo_;
+  std::vector<RewriteInfo> rinfo_;
+
+  /// Maintain info about nodes to add workspace edge
+  std::vector<WorkSpaceInfo> wsinfo_;
+
+  /// Maintain info about nodes to be merged
+  std::vector<MergeInfo> minfo_;
+
+  /// Maintain info about nodes to rewrite based on context
+  static std::vector<ContextInfo> cinfo_;
 
   /// Hash table to maintain nodes visited in the graph.
   std::unordered_set<const Node*> visited_nodes_;
@@ -209,6 +417,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // Mark the node as rewritten
   inline void MarkRewrittenNode(Node* n) { visited_nodes_.insert(n); }
 
+  // Clear all visited nodes
+  inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
+
   // Get the name of Mkl op from original TensorFlow op
   // We prefix 'Mkl' to the original op to get Mkl op.
   // TODO(nhasabni) We should move this to mkl_util.h.
@@ -218,6 +429,71 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return string(kMklOpPrefix) + name;
   }
 
+  // Return a node that can be merged with input node 'n'
+  //
+  // @return pointer to the node if we can find such a
+  //         node. Otherwise, it returns nullptr.
+  Node* CheckForNodeMerge(const Node* n) const;
+
+  // Merge predecessor node with its successor.
+  // Currently, we merge Conv2D with BiasAdd only.
+  //
+  // Input nodes succ and pred may be deleted if the call to
+  // this function is successful. Attempting to use the pointers
+  // after the call may result in undefined behavior.
+  //
+  // @input g - input graph, succ - successor node, pred - predecessor node
+  // @return Status::OK(), if merging is successful and supported.
+  //         Returns appropriate Status error code otherwise.
+  //         Graph is updated in case nodes are merged. Otherwise, it is
+  //         not updated.
+  Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
+
+  // Check if the node 'n' has any applicable rewrite rule.
+  // We check for 2 scenarios for rewrite.
+  //
+  // @return RewriteInfo* for the applicable rewrite rule
+  const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
+
+  // Default rewrite rule to be used in scenario 1 for rewrite.
+  // @return - true (since we want to always rewrite)
+  static bool AlwaysRewrite(const Node* n) { return true; }
+
+  // Rewrite rule that uses context-information for matching,
+  // used in scenario 2.
+  //
+  // @input - Node 'n' for which to search for matching context
+  // @return - true if matching context is found; false otherwise.
+  static bool ContextMatchRewrite(const Node* n);
+
+  // Helper function that searches for the matching contextinfo for the node.
+  // Implements breadth-first search in the data dependence graph for the
+  // gradient op in the backward direction.
+  //
+  // @input n - Node (gradient op) whose contextinfo is to be searched,
+  //        fwdn - pointer to node from the forward pass that this node
+  //        belongs to. fwdn cannot be NULL.
+  // @return Matching contextinfo in case a match is found; null otherwise.
+  //         Also updates *fwdn with pointer to forward node that this context
+  //         matches.
+  static const ContextInfo* SearchMatchingContext(const Node* n,
+                                                  const Node** fwdn);
+
+  // Rewrites input node to a new node specified by its matching rewrite info.
+  //
+  // Method first searches matching rewrite info for input node and then
+  // uses that info to rewrite.
+  //
+  // Input node may be deleted in case of rewrite. Attempting to use the node
+  // after the call can result in undefined behavior.
+  //
+  // @input g - input graph, n - Node to be rewritten,
+  //        ri - matching rewriteinfo
+  // @return Status::OK(), if the input node is rewritten;
+  //         Returns appropriate Status error code otherwise.
+  //         Graph is updated in case the input node is rewritten.
+  //         Otherwise, it is not updated.
+  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
+
   // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
   // in graph 'g'. Original node is input in 'orign'.
   //
@@ -230,28 +506,40 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
                      NodeBuilder* nb, Node* orign);
 
-  // Rewrite Node 'n' in graph 'g' with rewrite information specified in 'ni'
-  // Returns Status::OK() if node rewrite is successful, otherwise returns
-  // appropriate error status
-  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const NodesInfo& ni);
+  // Add workspace edge on the input or output side of Node 'orign' by using
+  // NodeBuilder 'nb' for the new node provided. If 'orign' does not dictate
+  // adding workspace edge then do not add it.
+  void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign,
+                                NodeBuilder* nb);
 
   // Functions specific to operators to copy attributes
   // We need operator-specific functions to copy attributes because the
   // framework does not provide any generic function for it.
-  static void CopyAttrsConv2D(Node* orign, NodeBuilder* nb);
+  static void CopyAttrsConv2D(const Node* orign, NodeBuilder* nb);
+  static void CopyAttrsBiasAddGrad(const Node* orign, NodeBuilder* nb);
+  static void CopyAttrsPooling(const Node* orign, NodeBuilder* nb);
+  static void CopyAttrsRelu(const Node* orign, NodeBuilder* nb);
 
   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
   // using node for original node 'orign' and return it in '*out'.
   // TODO(nhasabni) We should move this to mkl_util.h
   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
                              Node* orign);
+  void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
+                                   Node* orign);
 };
 
+std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
+
 // We register Mkl rewrite pass for phase 1 in pre-placement group.
 // Do not change the ordering of the Mkl passes.
 REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
                       MklLayoutRewritePass);
 
+//////////////////////////////////////////////////////////////////////////
+// Helper functions for creating new node
+//////////////////////////////////////////////////////////////////////////
+
 static void FillInputs(const Node* n,
                        gtl::InlinedVector<Node*, 4>* control_edges,
                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
@@ -273,47 +561,6 @@ static void FillInputs(const Node* n,
   }
 }
 
-//////////////////////////////////////////////////////////////////////////
-
-// Macros to build new node with different number of inputs.
-// We need this way because we need to specify all the inputs when
-// building a node. Comment at core/graph/node_builder.h, line 85-86.
-
-#define SETUP_INPUTS1(nb, op1)      \
-  do {                              \
-    nb->Input(op1.node, op1.index); \
-  } while (0)
-
-#define SETUP_INPUTS2(nb, op1, op2) \
-  do {                              \
-    nb->Input(op1.node, op1.index); \
-    nb->Input(op2.node, op2.index); \
-  } while (0)
-
-#define SETUP_INPUTS3(nb, op1, op2, op3) \
-  do {                                   \
-    nb->Input(op1.node, op1.index);      \
-    nb->Input(op2.node, op2.index);      \
-    nb->Input(op3.node, op3.index);      \
-  } while (0)
-
-#define SETUP_INPUTS4(nb, op1, op2, op3, op4) \
-  do {                                        \
-    nb->Input(op1.node, op1.index);           \
-    nb->Input(op2.node, op2.index);           \
-    nb->Input(op3.node, op3.index);           \
-    nb->Input(op4.node, op4.index);           \
-  } while (0)
-
-#define SETUP_INPUTS5(nb, op1, op2, op3, op4, op5) \
-  do {                                             \
-    nb->Input(op1.node, op1.index);                \
-    nb->Input(op2.node, op2.index);                \
-    nb->Input(op3.node, op3.index);                \
-    nb->Input(op4.node, op4.index);                \
-    nb->Input(op5.node, op5.index);                \
-  } while (0)
-
 // TODO(nhasabni) We should move this to mkl_util.h.
 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                                                  Node** out, Node* orign) {
@@ -335,6 +582,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                              // device as device of original
                              // node.
                     .Finalize(&**g, out));
+  (*out)->set_assigned_device_name(orign->assigned_device_name());
 }
 
 Status MklLayoutRewritePass::SetUpInputs(
@@ -359,7 +607,7 @@ Status MklLayoutRewritePass::SetUpInputs(
       TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
       // If this op has been rewritten, then its name must have been same as
       // Mkl op.
-      CHECK_EQ(mkl_layer_registry::IsMklLayer(n->type_string()), true);
+      CHECK_EQ(mkl_layer_registry::IsMklLayer(n->type_string(), T), true);
       // src slot number for Mkl tensor would be the one next to TF tensor
      // slot number.
      new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second + 1));
@@ -380,38 +628,140 @@ Status MklLayoutRewritePass::SetUpInputs(
   // N for Mkl tensors corresponding to each Tensorflow tensors.
   CHECK_EQ(new_inputs.size(), inputs.size() * 2);
 
-  // 2. Let's build the node with new inputs.
-  switch (new_inputs.size()) {
-    case 0:  // We don't need to do anything for no input as we have
-             // already built node.
-      break;
-    case 1:
-      SETUP_INPUTS1(nb, new_inputs[0]);
-      break;
-    case 2:
-      SETUP_INPUTS2(nb, new_inputs[0], new_inputs[1]);
-      break;
-    case 3:
-      SETUP_INPUTS3(nb, new_inputs[0], new_inputs[1], new_inputs[2]);
-      break;
-    case 4:
-      SETUP_INPUTS4(nb, new_inputs[0], new_inputs[1], new_inputs[2],
-                    new_inputs[3]);
-      break;
-    case 5:
-      SETUP_INPUTS5(nb, new_inputs[0], new_inputs[1], new_inputs[2],
-                    new_inputs[3], new_inputs[4]);
-      break;
-    default: {
-      return Status(error::Code::UNIMPLEMENTED,
-                    "Could not create node with given number of inputs");
-    }
-  }
+  // 2. Let's add the new inputs.
+  for (auto ni : new_inputs) {
+    nb->Input(ni.node, ni.index);
+  }
 
   return Status::OK();
 }
 
-void MklLayoutRewritePass::CopyAttrsConv2D(Node* orign, NodeBuilder* nb) {
+//////////////////////////////////////////////////////////////////////////
+// Helper functions related to workspace pass
+//////////////////////////////////////////////////////////////////////////
+
+// TODO(nhasabni) We should move this to mkl_util.h.
+void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
+    std::unique_ptr<Graph>* g, Node** out, Node* orign) {
+  // We use a tensor of shape {1} and value 0 to represent
+  // dummy float tensor. We need this as a dummy workspace tensor.
+  // Workspace tensor has type float.
+  const DataType dt = DataTypeToEnum<float>::v();
+  TensorProto proto;
+  proto.set_dtype(dt);
+  float zero[1] = {0};
+  proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
+                           4);
+  TensorShape dummy_shape({1});
+  dummy_shape.AsProto(proto.mutable_tensor_shape());
+  TF_CHECK_OK(
+      NodeBuilder((*g)->NewName("DMT"), "Const")
+          .Attr("value", proto)
+          .Attr("dtype", dt)
+          .Device(orign->def().device())  // We place this node on same
+                                          // device as device of original
+                                          // node.
+          .Finalize(&**g, out));
+  (*out)->set_assigned_device_name(orign->assigned_device_name());
+}
+
+void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
+                                                    Node* orign,
+                                                    NodeBuilder* nb) {
+  bool workspace_edge_added = false;
+  DataType T;
+  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+  for (auto ws : wsinfo_) {
+    if (orign->type_string() == ws.fwdop &&
+        mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()), T)) {
+      // If this op is a fwd op, then we need to check if there is an
+      // edge from this node's fwdslot to bwdop's bwdslot. If there is
+      // an edge, then we just add an attribute on this node for setting
+      // workspace_enabled to true. We don't add the actual workspace edge
+      // in this node. The actual workspace edge gets added in the backward
+      // op for this node.
+      for (const Edge* e : orign->out_edges()) {
+        if (e->src_output() == ws.fwdslot &&
+            e->dst()->type_string() == ws.bwdop &&
+            e->dst_input() == ws.bwdslot) {
+          nb->Attr("workspace_enabled", true);
+          VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
                  << orign->type_string();
+          workspace_edge_added = true;
+          // We found the edge that we were looking for, so break.
+          break;
+        }
+      }
+
+      if (!workspace_edge_added) {
+        // If we are here, then we did not find a backward operator for this
+        // node.
+ nb->Attr("workspace_enabled", false); + } + } else if (orign->type_string() == ws.bwdop && + mkl_layer_registry::IsMklLayer( + GetMklOpName(orign->type_string()), T)) { + // If this op is a bwd op, then we need to add workspace edge and + // it's Mkl tensor edge between its corresponding fwd op and this + // op. Corresponding fwd op is specified in 'fwdop' field of + // workspace info. fwdslot and bwdslot in workspace info specify + // an edge between which slots connect forward and backward op. + // Once all these criteria match, we add a workspace edge between + // wsfwdslot and wsbwdslot. It's corresponding Mkl tensor is added + // in wsfwdslot+1 and wsbwdslot+1. + for (const Edge* e : orign->in_edges()) { + if (e->src_output() == ws.fwdslot && + // We would have rewritten the forward op, so we need to use + // GetMklOpName call to get its Mkl name. + e->src()->type_string() == GetMklOpName(ws.fwdop) && + e->dst_input() == ws.bwdslot) { + nb->Attr("workspace_enabled", true); + // Add workspace edge between fwd op and bwd op. + nb->Input(e->src(), ws.wsfwdslot); + // Add Mkl tensor edge for workspace edge between fwd op and bwd op. + nb->Input(e->src(), ws.wsfwdslot + 1); + // In terms of input ordering, we add these calls to add Input + // here because workspace edge (and its Mkl tensor) is the last + // edge in the fwdop and bwdop. So all inputs before workspace + // tensor have been added by SetUpInputs function. + VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " + << orign->type_string(); + workspace_edge_added = true; + // We found the edge that we were looking for, so break. + break; + } + } + + // If we are here means we did not find fwd op that feeds to this + // bwd op. So in this case, we need to generate dummy tensors for + // workspace input and Mkl tensor for workspace, and set + // workspace_enabled to false. + if (!workspace_edge_added) { + nb->Attr("workspace_enabled", false); + Node* dmt_ws = nullptr; // Dummy tensor for workspace + Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace + GetDummyWorkspaceTensorNode(g, &dmt_ws, orign); + GetDummyMklTensorNode(g, &dmt_mkl_ws, orign); + CHECK_NOTNULL(dmt_ws); + CHECK_NOTNULL(dmt_mkl_ws); + nb->Input(dmt_ws, 0); // We add dummy tensor as workspace tensor. + nb->Input(dmt_mkl_ws, 0); // We add dummy tensor as Mkl + // tensor for workspace tensor. + VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " + << orign->type_string(); + } + } else { + // If this node does not match any workspace info, then we do not + // do anything special for workspace propagation for it. + } + } +} + +////////////////////////////////////////////////////////////////////////// +// Op-specific functions to copy attributes from old node to new node +////////////////////////////////////////////////////////////////////////// + +void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) { DataType T; string data_format; string padding; @@ -433,19 +783,280 @@ void MklLayoutRewritePass::CopyAttrsConv2D(Node* orign, NodeBuilder* nb) { nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); } +void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign, + NodeBuilder* nb) { + DataType T; + string data_format; + std::vector<int32> strides; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format)); + + // Add attributes to new node. 
+ nb->Attr("T", T); + nb->Attr("strides", strides); + nb->Attr("data_format", data_format); +} + +void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign, + NodeBuilder* nb) { + DataType T; + string data_format; + string padding; + std::vector<int32> ksize, strides; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "ksize", &ksize)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding)); + TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format)); + + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("ksize", ksize); + nb->Attr("strides", strides); + nb->Attr("padding", padding); + nb->Attr("data_format", data_format); +} + +void MklLayoutRewritePass::CopyAttrsRelu(const Node* orign, NodeBuilder* nb) { + DataType T; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T)); + + // Add attributes to new node. + nb->Attr("T", T); +} + +////////////////////////////////////////////////////////////////////////// +// Helper functions related to node merge pass +////////////////////////////////////////////////////////////////////////// + +Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { + // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite + // once we support BiasAddGrad as Mkl layer. + + // Search for all matching mergeinfo. + // We allow more than one match for extensibility. + std::vector<const MergeInfo*> matching_mi; + for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { + if (a->type_string() == mi->succ) { + matching_mi.push_back(&*mi); + } + } + + for (const MergeInfo* mi : matching_mi) { + const int N_in = a->num_inputs(); + if (mi->op >= N_in) { + continue; + } + + // Get the control edges and input of node + gtl::InlinedVector<Node*, 4> a_control_edges; + gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); + FillInputs(a, &a_control_edges, &a_in); + + // Get operand op of the operator + Node* b = nullptr; + b = a_in[mi->op].first; + if (b == nullptr || (b->type_string() != mi->pred)) { + // NOTE: Should the first check be assert? + continue; + } + + gtl::InlinedVector<Node*, 4> b_control_edges; + gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in); + FillInputs(b, &b_control_edges, &b_in); + + // Shouldn't merge if a and b have different control edges. + if (a_control_edges != b_control_edges) { + continue; + } else { + // We found a match. + return b; + } + } + + return nullptr; +} + +Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, + Node* pred) { + CHECK_NOTNULL(succ); + CHECK_NOTNULL(pred); + + if (succ->type_string() == csinfo_.biasadd && + pred->type_string() == csinfo_.mklconv2d) { + // 1. Get all attributes from input nodes. 
+    DataType T_pred, T_succ;
+    string padding;
+    std::vector<int32> strides;
+    string data_format_pred, data_format_succ;
+    bool use_cudnn_on_gpu;
+    TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
+    TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
+    TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
+    TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
+    TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
+    TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
+    TF_CHECK_OK(
+        GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
+    // We check to ensure that the data formats of both succ and pred are the
+    // same. We expect them to be the same, so we could enforce this with an
+    // assert. But an assert can be too strict, so we enforce this as a check.
+    // If the check fails, then we do not merge the two nodes.
+    // We also do the same check for devices.
+    if (data_format_pred != data_format_succ || T_pred != T_succ ||
+        pred->assigned_device_name() != succ->assigned_device_name() ||
+        pred->def().device() != succ->def().device()) {
+      return Status(error::Code::INVALID_ARGUMENT,
+                    "data_format or T attribute or devices of Conv2D and "
+                    "BiasAdd do not match. Will skip node merge optimization");
+    }
+
+    const int succ_num = succ->num_inputs();
+    gtl::InlinedVector<Node*, 4> succ_control_edges;
+    gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
+    FillInputs(succ, &succ_control_edges, &succ_in);
+
+    const int pred_num = pred->num_inputs();
+    gtl::InlinedVector<Node*, 4> pred_control_edges;
+    gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
+    FillInputs(pred, &pred_control_edges, &pred_in);
+
+    // We need to ensure that there is only 1 edge between Conv2D and BiasAdd.
+    // Otherwise, merging is semantically incorrect. (This precondition is
+    // condensed in the sketch after this hunk.)
+    if (pred->out_edges().size() != 1) {
+      return Status(error::Code::INVALID_ARGUMENT,
+                    "Conv2D has multiple outputs. "
+                    "Will skip node merge optimization");
+    }
+
+    for (const Edge* e : pred->out_edges()) {
+      if (e->dst() != succ) {
+        return Status(error::Code::INVALID_ARGUMENT,
+                      "Conv2D does not feed to BiasAdd. "
+                      "Will skip node merge optimization");
+      }
+    }
+
+    // 2. Get inputs from both the nodes.
+    // Find the 2 inputs from the conv and the bias from the BiasAdd.
+    // Get operand 0, 1 of conv2D and their Mkl tensors.
+    CHECK_EQ(pred->in_edges().size(), 4);  // MklConv2D must have 4 inputs.
+    // Get operand 1 of add_bias
+    // BiasAdd must have 2 inputs: Conv, bias
+    CHECK_EQ(succ->in_edges().size(), 2);
+    Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
+    int oper3_mkl_slot = 0;     // For dummy MKL tensor node, output slot is 0.
+    GetDummyMklTensorNode(g, &oper3_mkl, succ);  // Get dummy Mkl tensor node
+    // as BiasAdd does not have an Mkl tensor as input.
+    CHECK_NOTNULL(oper3_mkl);
+
+    // We will use the node name of BiasAdd as the name of the new node.
+    // Build new node. We use same name as original node, but change the op
+    // name.
+    NodeBuilder nb(succ->name(), csinfo_.mklconv2dwithbias);
+    nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
+    nb.Input(pred_in[1].first, pred_in[1].second);  // Mkl for In1
+    nb.Input(pred_in[2].first, pred_in[2].second);  // In2 of Conv2D
+    nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2
+    nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
+    nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
+
+    // Copy attributes from Conv2D to Conv2DWithBias.
+    CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
+
+    // Copy the device assigned to old node to new node.
+    nb.Device(succ->def().device());
+
+    // Create node.
+    Node* newn;
+    nb.Finalize(&**g, &newn);
+    CHECK_NOTNULL(newn);
+
+    // Set the Mkl layer label for this op.
+    newn->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
+
+    // Incoming edges are fixed; now fix the outgoing edges.
+    for (const Edge* e : succ->out_edges()) {
+      (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+    }
+
+    // Copy device assigned to old node to new node.
+    // It's ok to use pred or succ as we have enforced a check that
+    // both have the same device assigned.
+    newn->set_assigned_device_name(pred->assigned_device_name());
+
+    VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
+            << ", and node: " << succ->DebugString()
+            << ", into node:" << newn->DebugString();
+
+    (*g)->RemoveNode(succ);
+    (*g)->RemoveNode(pred);
+    MarkRewrittenNode(newn);
+
+    return Status::OK();
+  }
+
+  return Status(error::Code::UNIMPLEMENTED,
+                "Unimplemented case for node merge optimization.");
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Helper functions for node rewrite
+//////////////////////////////////////////////////////////////////////////
+
 Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
-                                         const NodesInfo& ni) {
-  VLOG(1) << "MKLLayoutRewritePass: Original node:" << orign->DebugString();
+                                         const RewriteInfo* ri) {
+  CHECK_NOTNULL(ri);
+  CHECK_NOTNULL(orign);
+
+  VLOG(1) << "MklLayoutRewritePass: Original node:" << orign->DebugString();
+
+  // Check if this is scenario 2 (context-based rewrite).
+  // Get the matching ContextInfo if it is.
+  const Node* fwdn = nullptr;
+  const ContextInfo* ci = nullptr;
+  bool is_context_based_rewrite = false;
+  if ((ci = SearchMatchingContext(orign, &fwdn)) != nullptr) {
+    CHECK_NOTNULL(fwdn);
+    is_context_based_rewrite = true;
+
+    // Sanity checks for context-based rewrite (if any)
+    if (orign->type_string() == csinfo_.biasaddgrad &&
+        ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
+      DataType orig_T, ctx_T;
+      string orig_data_format, ctx_data_format;
+      TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &orig_T));
+      TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &orig_data_format));
+      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "T", &ctx_T));
+      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "data_format", &ctx_data_format));
+
+      if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
+          orign->assigned_device_name() != fwdn->assigned_device_name() ||
+          orign->def().device() != fwdn->def().device()) {
+        return Status(
+            error::Code::INVALID_ARGUMENT,
+            "data_format or T attribute or devices of BiasAddGrad and "
+            "Conv2D do not match. Will skip node rewrite optimization");
+      }
+    }
+  }
+
   // Get all inputs.
   const int num = orign->num_inputs();
-  CHECK_EQ(num, ni.numins);
+  CHECK_EQ(num, ri->numins);
   gtl::InlinedVector<Node*, 4> control_edges;
   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
   FillInputs(orign, &control_edges, &inputs);
 
   // Build new node. We use same name as original node, but change the op name.
-  NodeBuilder nb(orign->name().c_str(), ni.newname.c_str());
+  NodeBuilder nb(orign->name().c_str(), ri->newname.c_str());
   // Copy user-specified device assigned to original node to new node.
   nb.Device(orign->def().device());
   // Set up new inputs to the rewritten node.
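// Illustration (a condensed sketch, not code from this pass): the
// single-consumer precondition that MergeNode enforces above can be
// expressed as a self-contained helper over the tensorflow Node/Edge API.
// The helper name is hypothetical; like the pass, it treats any extra out
// edge, including a control edge, as disqualifying.
//
//   static bool OnlyFeedsBiasAdd(const Node* conv2d) {
//     if (conv2d->out_edges().size() != 1) return false;  // one consumer only
//     const Edge* e = *conv2d->out_edges().begin();
//     return e->dst()->type_string() == "BiasAdd";
//   }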
@@ -453,20 +1064,48 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
   if (s != Status::OK()) {
     return s;
   }
-  // Copy attributes from original node to new node.
-  ni.copyattrs(orign, &nb);
+
+  // Copy attributes from original node to new node (for scenario 1).
+  // For context-based rewrite, we use the context to copy the attributes.
+  if (is_context_based_rewrite) {
+    if (orign->type_string() == csinfo_.biasaddgrad &&
+        ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
+      CHECK_NOTNULL(fwdn);
+      ri->copyattrs(fwdn, &nb);
+    } else {
+      return Status(error::Code::UNIMPLEMENTED,
+                    "Unimplemented case for node rewrite optimization.");
+    }
+  } else {
+    ri->copyattrs(const_cast<const Node*>(orign), &nb);
+  }
 
   // Set the Mkl layer label for this op.
   nb.Attr("_kernel", mkl_layer_registry::kMklLayerLabel);
-  Node* newn = nullptr;
+
+  // Add workspace edge to this node if needed.
+  // We add workspace edge only for MaxPool, LRN and BatchNorm.
+  AddWorkSpaceEdgeIfNeeded(g, orign, &nb);
 
   // Finalize graph and get new node.
+  Node* newn = nullptr;
   TF_CHECK_OK(nb.Finalize(&**g, &newn));
   CHECK_NOTNULL(newn);
 
   // Incoming edges from 'orign' node to new 'newn' node are already copied
   // in BuildNode. Copy outgoing edges from 'orign' node to new 'newn' node.
+  // Since the outputs also follow the same ordering among Tensorflow tensors
+  // and Mkl tensors, we need to connect Tensorflow tensors appropriately.
+  // Specifically, the nth output of the original node will become the 2*nth
+  // output of the Mkl node. GetTensorDataIndex provides this mapping function.
   for (const Edge* e : orign->out_edges()) {
-    (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+    // We need to handle control-edges by using their original slot number.
+    // Generally, -1 is reserved for control slot.
+    if (e->src_output() < 0) {
+      (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+    } else {
+      (*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()), e->dst(),
+                    e->dst_input());
+    }
   }
 
   // Copy the runtime device assigned from original code to new node.
@@ -476,10 +1115,123 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
   (*g)->RemoveNode(orign);
   MarkRewrittenNode(newn);
 
-  VLOG(1) << "MKLLayoutRewritePass: New node:" << newn->DebugString();
+  VLOG(1) << "MklLayoutRewritePass: New node:" << newn->DebugString();
   return Status::OK();
 }
 
+const MklLayoutRewritePass::ContextInfo*
+MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
+  CHECK_NOTNULL(n);
+  CHECK_NOTNULL(fwdn);
+  *fwdn = nullptr;
+
+  // Search for matching contextinfo based on node name.
+  // There could be more than one matching contextinfo.
+  bool is_matching_cinfo_found = false;
+  std::vector<const ContextInfo*> mci;
+  for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
+    if (n->type_string() == ci->node) {
+      mci.push_back(&*ci);
+      is_matching_cinfo_found = true;
+    }
+  }
+  // If no matching contextinfo is found, return immediately.
+  if (!is_matching_cinfo_found) {
+    return nullptr;
+  }
+
+  VLOG(1) << "MklLayoutRewritePass: Searching the graph backwards for: "
+          << n->type_string();
+
+  // Now we will check for the forward op name for the context info in the
+  // data flow graph. Get the max hops we should search for the fwd node.
+  // We are now going to search (breadth-first) backwards in the data
+  // dependence graph (for up to max hops) from n for the node
+  // specified in fwd.
+  // queue to maintain nodes to be visited and depth info for
+  // breadth-first search
+  std::queue<std::pair<const Node*, int>> nqueue;
+  const Node* curr_node = n;
+  size_t curr_depth = 0;
+  nqueue.push(std::make_pair(curr_node, curr_depth));
+
+  // Nodes already visited; kept for the whole search so that each node
+  // is enqueued at most once.
+  std::set<const Node*> visited_nodes;
+
+  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
+    std::pair<const Node*, int> curr_pair = nqueue.front();
+    nqueue.pop();
+
+    curr_node = curr_pair.first;
+    curr_depth = curr_pair.second;
+    CHECK_NOTNULL(curr_node);
+
+    VLOG(1) << "MklLayoutRewritePass: Visiting node: "
+            << curr_node->type_string() << " at depth: " << curr_depth
+            << " for node: " << n->type_string();
+
+    // If we find a match, we return immediately.
+    for (const ContextInfo* ci : mci) {
+      if (curr_node->type_string() == ci->fwd) {
+        *fwdn = curr_node;
+        return ci;
+      }
+    }
+
+    // Else we explore backward edges from current node.
+    // Add the source nodes of all incoming edges of the node to the queue.
+    for (const Edge* e : curr_node->in_edges()) {
+      // We do not visit an already visited node.
+      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
+        // Depth of these nodes is 1 more than the depth of current node.
+        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
+        visited_nodes.insert(e->src());
+      }
+    }
+  } /* while */
+
+  return nullptr;
+}
+
+bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
+  const Node* fwdn = nullptr;
+  return SearchMatchingContext(n, &fwdn) != nullptr;
+}
+
+const MklLayoutRewritePass::RewriteInfo*
+MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
+  CHECK_NOTNULL(n);
+
+  // First check if the node along with its type is supported by the MKL
+  // layer. We do not want to rewrite an op into an Mkl op if types are not
+  // supported. E.g., MklRelu does not support INT32. So we cannot rewrite
+  // Relu to MklRelu if the type is INT32.
+  DataType T;
+  if (!GetNodeAttr(n->def(), "T", &T).ok()) {
+    return nullptr;
+  }
+  if (!mkl_layer_registry::IsMklLayer(GetMklOpName(n->type_string()), T)) {
+    return nullptr;
+  }
+
+  // We support 2 types of node rewrites:
+  // 1. Rewriting BiasAddGrad depending on its context.
+  // 2. Rewriting an op to Mkl op always.
+  // We return the matching RewriteInfo when either of these applies.
+
+  // Find matching RewriteInfo and then check that rewrite rule applies.
+  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
+    if (n->type_string().compare(ri->name) == 0 && ri->rewriterule(n)) {
+      return &*ri;
+    }
+  }
+
+  // Else return not found.
+  return nullptr;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Run function for the pass
+///////////////////////////////////////////////////////////////////////////////
+
 bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
   bool result = false;
   CHECK_NOTNULL(g);
@@ -494,40 +1246,46 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
       continue;
     }
 
-    for (const NodesInfo& ni : ninfo_) {
-      DataType dtype = DT_INVALID;
-      // An op needs to have data type (T) attribute and its corresponding
-      // Mkl op name must be supported.
-      if (GetNodeAttr(n->def(), "T", &dtype) == Status::OK() &&
-          mkl_layer_registry::IsMklLayer(GetMklOpName(n->type_string())) &&
-          n->type_string().compare(ni.name) == 0) {
-        string node_name = n->name();
-        string op_name = n->type_string();
-
-        VLOG(1) << "MKLLayoutRewritePass: Scheduled node " << node_name
-                << " with op " << op_name << " for rewrite using"
-                << " layout optimization.";
-
-        if (RewriteNode(g, n, ni) == Status::OK()) {
-          VLOG(1) << "MKLLayoutRewritePass: Successfully rewrote node "
-                  << node_name << " with op " << op_name
-                  << " for Mkl layout optimization.";
-          result = true;
-          break;  // We found matching nodesinfo so no need to search next.
-        }
-      }
-    }
+    const RewriteInfo* ri = nullptr;
+    Node* predn = nullptr;
+    // We will first search if the node is to be rewritten.
+    if ((ri = CheckForNodeRewrite(n)) != nullptr) {
+      string node_name = n->name();
+      string op_name = n->type_string();
+
+      VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
+              << " with op " << op_name << " for rewrite using"
+              << " layout optimization.";
+
+      if (RewriteNode(g, n, ri) == Status::OK()) {
+        VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
+                << " with op " << op_name << " for Mkl layout optimization.";
+        result = true;
+      }
+    } else if ((predn = CheckForNodeMerge(n)) != nullptr) {
+      // Otherwise, we will check if the node is to be merged.
+      string n1_name = n->name();
+      string n2_name = predn->name();
+
+      VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
+              << n2_name << " for merging";
+
+      if (MergeNode(g, n, predn) == Status::OK()) {
+        VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
+                << n2_name;
+        result = true;
+      }
+    }
  }
 
   DumpGraph("After running MklLayoutRewritePass", &**g);
 
+  // Clear marked nodes as the same graph pass may be used multiple times.
+  UnMarkRewrittenNodes();
+
   return result;
 }
 
-///////////////////////////////////////////////////////////////////////////////
-// Run function for the pass
-///////////////////////////////////////////////////////////////////////////////
-
 bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
   return MklLayoutRewritePass().RunPass(g);
 }
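For reference, the backward breadth-first search performed by SearchMatchingContext above can be distilled into the following self-contained sketch. The function name is illustrative and not part of the pass; it assumes only the tensorflow Graph/Node/Edge API and a caller-supplied hop limit (kNodeMergeContextMaxDepth in the pass).

static const Node* FindForwardNode(const Node* start, const string& fwd_op,
                                   size_t max_hops) {
  std::queue<std::pair<const Node*, size_t>> nqueue;  // node and its depth
  std::set<const Node*> visited;  // scoped to the whole search
  nqueue.push({start, 0});
  visited.insert(start);
  while (!nqueue.empty()) {
    const Node* n = nqueue.front().first;
    const size_t depth = nqueue.front().second;
    nqueue.pop();
    if (n->type_string() == fwd_op) return n;  // matching forward op found
    if (depth >= max_hops) continue;           // do not expand past the limit
    for (const Edge* e : n->in_edges()) {      // walk backwards along inputs
      if (visited.insert(e->src()).second) {   // enqueue each node once
        nqueue.push({e->src(), depth + 1});
      }
    }
  }
  return nullptr;  // no matching forward op within max_hops
}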