-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc      | 2177 |
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 1865 |
2 files changed, 1 insertion(+), 4041 deletions(-)
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 7394b1cddf..42a35727db 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -45,2181 +45,6 @@ limitations under the License. namespace tensorflow { -#ifdef INTEL_MKL_ML_ONLY - -// This pass implements rewriting of graph to support following scenarios: -// (A) Merging nodes in the graph -// (B) Rewriting a node in the graph to a new node -// Rewrite happens under following 2 scenarios: -// 1) Propagating Mkl layout as an additional output tensor -// (we will loosely call a tensor that carries Mkl layout as Mkl tensor -// henceforth.) from every Mkl supported NN layer. -// 2) Context-based rewrite: This is needed in order to optimize -// gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and -// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into -// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad. -// This is context-specific optimization, where the context is the -// forward operator that the BiasAddGrad corresponds to. -// -// Example of A : Merging nodes in the graph -// ----------------------------------------- -// Currently, we merge Conv2D+AddBias together. Consider Conv2D and BiasAdd as: -// -// O = Conv2D(A, B) -// P = BiasAdd(O, C) -// -// We merge them into Conv2DWithBias as: -// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) -// -// The meaning of A_m, B_m and C_m is explained in B.1. -// -// Merge rules: -// - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ -// goes to BiasAdd. -// - Also, the intersection of attributes of both the nodes must have same -// values. -// - Both the nodes must have been assigned to same device (if any). -// -// Example of B.1 : Rewriting nodes to Mkl nodes -// --------------------------------------------- -// Consider a Relu node. Current definition of Relu node looks like: -// -// O = Relu(A) -// -// Relu has 1 input (A), and 1 output (O). -// -// This rewrite pass will generate a new graph node for Relu (new node is -// called MklRelu) as: -// -// O, O_m = MklRelu(A, A_m) -// -// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is -// same as input A of Relu; output O is same as output O of Relu. O_m is the -// additional output tensor that will be set by MklRelu, and it represents -// Mkl tensor corresponding to O -- in other words, O_m is some kind of -// metadata for O. A_m is additional input of Relu, and it represents metadata -// for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives -// this metadata from previous node in the graph. -// -// When a previous node in the graph is an Mkl node, A_m will represent a valid -// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent -// a dummy Mkl tensor. -// -// Rewriting rules: -// - Selection of a node for rewriting happens by registering the op type of -// the node with the rewriting pass. If the op type is not registered, then -// all nodes of this op type will not be rewritten. -// - Number of inputs after rewriting: -// Since for every input Tensorflow tensor, the rewritten node gets Mkl -// tensor(s), rewritten node gets 2*N inputs, where N is the number of -// inputs for the original node. 
-// - Number of outputs after rewriting: -// Since for every output Tensorflow tensor, the rewritten node generates -// Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the -// number of outputs of the original node. -// - Ordering of Tensorflow tensors and Mkl tensors: -// Since every rewritten node generates twice the number of inputs and -// outputs, one could imagine various orderings among Tensorflow tensors -// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as -// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m -// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m -// order. Among N inputs one can get N! permutations. -// -// So the question is: which order do we follow? We support 2 types of -// orderings: (1) interleaved, and (2) contiguous. Interleaved ordering -// follows an intuitive order where an Mkl tensor follows the -// corresponding Tensorflow tensor immediately. In the context of the -// above example, it will be: A, A_m, B, B_m. Note that the ordering rule -// applies to both the inputs and outputs. Contiguous ordering means -// all the Tensorflow tensors are contiguous followed by all the Mkl -// tensors. We use contiguous ordering as default. -// -// Graph rewrite algorithm: -// Algorithm: Graph Rewrite -// Input: Graph G, Names of the nodes to rewrite and their new names -// Output: Modified Graph G' if the nodes are modified, G otherwise. -// Start: -// N = Topological_Sort(G) // N is a set of nodes in toposort order. -// foreach node n in N -// do -// if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. -// then -// E = set of <incoming edge and its src_output slot> of n -// E' = {} // a new set of edges for rewritten node -// foreach <e,s> in E -// do -// E' U {<e,s>} // First copy edge which generates Tensorflow -// // tensor as it is -// m = Source node of edge e -// if Is_Rewritten(m) // Did we rewrite this node in this pass? -// then -// E' U {<m,s+1>} // If yes, then m will generate an Mkl -// // tensor as an additional output. -// else -// d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy -// // Mkl tensor. -// E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot. -// fi -// done -// n' = Build_New_Node(G,new_name,E') -// Mark_Rewritten(n') // Mark the new node as being rewritten. -// fi -// done -// -// Explanation: -// For graph rewrite, we visit nodes of the input graph in the -// topological sort order. With this ordering, we visit nodes in the -// top-to-bottom fashion. We need this order because while visiting a -// node we want that all of its input nodes are visited and rewritten if -// applicable. This is because if we need to rewrite a given node -// then all of its input nodes need to be fixed (in other words they -// cannot be deleted later.) -// -// While visiting a node, we first check if the op type of the node is -// an Mkl op. If it is, then we rewrite that node after constructing -// new inputs to the node. If the op type of the node is not Mkl op, -// then we do not rewrite that node. -// -// Handling workspace propagation for certain ops: -// -// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require -// passing of a workspace from their respective forward ops. Workspace -// tensors provide memory for storing results of intermediate operations -// which are helpful in backward propagation. 
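// [Editorial aside -- not part of the original diff.] The ordering rules above
// can be made concrete with a minimal standalone sketch (plain C++, no
// TensorFlow dependencies; the function and variable names below are
// illustrative only, not the pass's real helpers). It builds the 2*N-element
// input list of a rewritten node from the N original TensorFlow inputs, in
// either interleaved (A, A_m, B, B_m) or contiguous (A, B, A_m, B_m) order;
// the pass described here defaults to the contiguous ordering.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> RewrittenInputs(const std::vector<std::string>& tf_inputs,
                                         bool contiguous) {
  std::vector<std::string> result;
  result.reserve(2 * tf_inputs.size());
  if (contiguous) {
    for (const auto& in : tf_inputs) result.push_back(in);         // A, B, ...
    for (const auto& in : tf_inputs) result.push_back(in + "_m");  // A_m, B_m, ...
  } else {
    for (const auto& in : tf_inputs) {  // A, A_m, B, B_m, ...
      result.push_back(in);
      result.push_back(in + "_m");
    }
  }
  return result;
}

int main() {
  // Conv2D(A, B) -> _MklConv2D(A, B, A_m, B_m) under contiguous ordering.
  for (const auto& in : RewrittenInputs({"A", "B"}, /*contiguous=*/true)) {
    std::cout << in << ' ';  // prints: A B A_m B_m
  }
  std::cout << '\n';
  return 0;
}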
TensorFlow does not have -// a notion of a workspace and as a result does not allow producing -// additional outputs from these forward ops. For these ops, we need -// to add 2 extra edges between forward ops and their corresponding -// backward ops - the first extra edge carries a workspace tensor and -// the second one carries an Mkl tensor for the workspace tensor. -// -// Example: -// -// Typical graph for MaxPool and its gradient looks like: -// -// A = MaxPool(T) -// B = MaxPoolGrad(X, A, Y) -// -// We will transform this graph to propagate the workspace as: -// (with the contiguous ordering) -// -// A, W, A_m, W_m = MklMaxPool(T, T_m) -// B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) -// -// Here W is the workspace tensor. Transformed tensor names with the -// suffix _m are Mkl tensors, and this transformation has been done -// using the algorithm discussed earlier. The transformation for -// workspace propagation only adds extra outputs (W, W_m) for a forward -// op and connects them to the corresponding backward ops. -// -// Terms: -// -// Forward op name = name of the op in the forward pass -// where a workspace tensor originates (MaxPool in this example) -// Backward op name = name of the op in the backward pass that receives -// a workspace tensor from the forward op (MaxPoolGrad in the example) -// Slot = Position of the output or input slot that will be -// used by the workspace tensor (1 for MklMaxPool as W is the 2nd -// output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad) -// -// Question: -// -// How do we associate a backward op to a forward op? There can be more -// than one op with the exact same name. -// -// In this example, we associate MaxPoolGrad with MaxPool. But there -// could be more than one MaxPool ops. To solve this problem, we look -// for _direct_ edge between a forward op and a backward op (tensor A is -// flowing along this edge in the example). -// -// How do we transform forward and backward ops when there is no direct -// edge between them? In such a case, we generate dummy tensors for -// workspace tensors. For the example, transformation of MaxPool will -// be exactly same as it would be when there is a direct edge between -// the forward and the backward op --- it is just that MaxPool won't -// generate any workspace tensor. For MaxPoolGrad, the transformation -// will also be same, but instead of connecting W and W_m with the -// outputs of MaxPool, we will produce dummy tensors for them, and we -// will set workspace_enabled attribute to false. -// -// Example of B.2 : Context-based node rewrite -// ------------------------------------------- -// Consider BiasAddGrad op as: -// -// O = _MklConv2D(A, B, C, A_m, B_m, C_m) -// P = BiasAddGrad(O) -// -// Then we rewrite it as: -// -// P = Conv2DWithBiasBackpropBias(O, O_m) -// -// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending -// on the matching 'context'. The term context is loosely related to which -// forward op is _associated_ to BiasAddGrad. If it is _MklConv2DWithBias then -// we consider it Conv2D context; if it is MatMul, then it is MatMul context. - -class MklLayoutRewritePass : public GraphOptimizationPass { - public: - MklLayoutRewritePass() { - // NOTE: names are alphabetically sorted. 
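// [Editorial aside -- not part of the original diff.] The workspace example
// above (A, W, A_m, W_m = MklMaxPool(T, T_m)) also fixes where the Mkl
// metadata tensors live under contiguous ordering. A tiny standalone check
// (plain C++; MetaDataSlot is a made-up name for this sketch, not the pass's
// real GetTensorMetaDataIndex/DataIndexToMetaDataIndex helpers): with n data
// outputs, the metadata tensor for data slot s sits at slot n + s.
#include <cassert>

int MetaDataSlot(int data_slot, int num_data_outputs) {
  // Contiguous ordering: all data tensors first, then all metadata tensors.
  return num_data_outputs + data_slot;
}

int main() {
  // MklMaxPool produces A (slot 0) and workspace W (slot 1), followed by
  // A_m (slot 2) and W_m (slot 3).
  assert(MetaDataSlot(/*data_slot=*/0, /*num_data_outputs=*/2) == 2);  // A_m
  assert(MetaDataSlot(/*data_slot=*/1, /*num_data_outputs=*/2) == 3);  // W_m
  return 0;
}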
- csinfo_.addn = "AddN"; - csinfo_.avg_pool = "AvgPool"; - csinfo_.avg_pool_grad = "AvgPoolGrad"; - csinfo_.bias_add = "BiasAdd"; - csinfo_.bias_add_grad = "BiasAddGrad"; - csinfo_.concat = "Concat"; - csinfo_.concatv2 = "ConcatV2"; - csinfo_.conv2d = "Conv2D"; - csinfo_.conv2d_grad_input = "Conv2DBackpropInput"; - csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; - csinfo_.fused_batch_norm = "FusedBatchNorm"; - csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; - csinfo_.identity = "Identity"; - csinfo_.lrn = "LRN"; - csinfo_.lrn_grad = "LRNGrad"; - csinfo_.matmul = "MatMul"; - csinfo_.max_pool = "MaxPool"; - csinfo_.max_pool_grad = "MaxPoolGrad"; - csinfo_.mkl_conv2d = "_MklConv2D"; - csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput"; - csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; - csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; - csinfo_.mkl_conv2d_with_bias_backprop_bias = - "_MklConv2DWithBiasBackpropBias"; - csinfo_.relu = "Relu"; - csinfo_.relu_grad = "ReluGrad"; - csinfo_.reshape = "Reshape"; - csinfo_.split = "Split"; - // Element-wise ops. Ensure you also add any new ops to IsOpElementWise - // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the - // MklInputConversion op is added before it. - csinfo_.add = "Add"; - csinfo_.maximum = "Maximum"; - csinfo_.mul = "Mul"; - csinfo_.squared_difference = "SquaredDifference"; - csinfo_.sub = "Sub"; - // End - element-wise ops. See note above. - - // NOTE: names are alphabetically sorted. - rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), - CopyAttrsAddN, AddNRewrite, nullptr}); - rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.avg_pool, - mkl_op_registry::GetMklOpName(csinfo_.avg_pool), - CopyAttrsPooling, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.avg_pool_grad, - mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), - CopyAttrsPooling, AlwaysRewrite, nullptr}); - // BiasAddGrad gets written into Conv2DWithBiasBackpropBias depending - // on if context contains Conv2D. - rinfo_.push_back({csinfo_.bias_add_grad, - csinfo_.mkl_conv2d_with_bias_backprop_bias, - CopyAttrsBiasAddGrad, ContextMatchRewrite, - &biasaddgrad_conv2dwithbias_context_}); - // BiasAddGrad gets written into BiasAddGrad depending on if context - // contains MatMul. 
- rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul, - CopyAttrsBiasAddGrad, ContextMatchRewrite, - &biasaddgrad_matmul_context_}); - rinfo_.push_back({csinfo_.concat, - mkl_op_registry::GetMklOpName(csinfo_.concat), - CopyAttrsConcat, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.concatv2, - mkl_op_registry::GetMklOpName(csinfo_.concatv2), - CopyAttrsConcatV2, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.conv2d, - mkl_op_registry::GetMklOpName(csinfo_.conv2d), - CopyAttrsConv2D, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.conv2d_grad_filter, - mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), - CopyAttrsConv2D, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.conv2d_grad_input, - mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), - CopyAttrsConv2D, AlwaysRewrite, nullptr}); - - rinfo_.push_back({csinfo_.fused_batch_norm, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), - CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); - rinfo_.push_back( - {csinfo_.fused_batch_norm_grad, - mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), - CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.identity, - mkl_op_registry::GetMklOpName(csinfo_.identity), - CopyAttrsIdentity, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), - CopyAttrsLRN, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.lrn_grad, - mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), - CopyAttrsLRN, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.max_pool, - mkl_op_registry::GetMklOpName(csinfo_.max_pool), - CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr}); - rinfo_.push_back({csinfo_.max_pool_grad, - mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), - CopyAttrsPooling, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.maximum, - mkl_op_registry::GetMklOpName(csinfo_.maximum), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.relu_grad, - mkl_op_registry::GetMklOpName(csinfo_.relu_grad), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.reshape, - mkl_op_registry::GetMklOpName(csinfo_.reshape), - CopyAttrsReshape, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.squared_difference, - mkl_op_registry::GetMklOpName(csinfo_.squared_difference), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), - CopyAttrsDataType, AlwaysRewrite, nullptr}); - - // Add info about which ops to add workspace edge to and the slots. 
- wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); - wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); - - // Add a rule for merging nodes - minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, - csinfo_.mkl_conv2d_with_bias}); - - biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, - IsBiasAddGradInMatMulContext}; - - biasaddgrad_conv2dwithbias_context_ = { - csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, - IsBiasAddGradInConv2DWithBiasContext}; - - cinfo_.push_back(&biasaddgrad_matmul_context_); - cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); - } - - // Standard interface to run pass - Status Run(const GraphOptimizationPassOptions& options); - - // Helper function which does most of heavy lifting for rewriting - // Mkl nodes to propagate Mkl tensor as additional output - // - // Extracts common functionality between Run public interface and - // test interface. - // - // @return true, if and only if graph is mutated; false otherwise. - bool RunPass(std::unique_ptr<Graph>* g); - - /// Structure to specify the context information used in a node rewrite rule - typedef struct { - string node; // Name of the node to be rewritten - string fwd; // Name of the node in the forward pass that this node - // corresponds to - std::function<bool(const Node*, const Node**, void* c)> context_match_fn; - } ContextInfo; - - /// Structure to specify the name of an original node, its new name after - /// rewrite, the number of inputs to the original node, the function to - /// be used to copy attributes for the op, and the rule (if any) which - /// must hold for rewriting the node - typedef struct { - string name; // Original name of op of the node in the graph - string new_name; // New name of the op of the node in the graph - // A function handler to copy attributes from an old node to a new node. - std::function<void(const Node*, NodeBuilder*)> copy_attrs; - // A rule under which to rewrite this node - std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule; - // ContextInfo, if any, to be used for rewrite - ContextInfo* context; - } RewriteInfo; - - /// Structure to specify a forward op, a backward op, and the slot numbers - /// in the forward and backward ops where we will add a workspace edge. - typedef struct { - string fwd_op; // Name of a forward op in the graph - string bwd_op; // Name of a backward op in the graph - int fwd_slot; // Output slot in the forward op node where actual - // output tensor resides - int bwd_slot; // Input slot in the backward op node where actual - // input tensor resides - int ws_fwd_slot; // Output slot in the forward op node where workspace - // edge is added - int ws_bwd_slot; // Input slot in the backward op node where workspace - // edge is added - } WorkSpaceInfo; - - /// Structure to specify information used in node merge - typedef struct { - string pred; // Predecessor node string - string succ; // Successor node string - int op; // The operand no the predecessor node corresponds - // to the successor node - string new_node; // Name of the node after merge - } MergeInfo; - - /// Structure to store all constant strings - /// NOTE: names are alphabetically sorted. 
- typedef struct { - string addn; - string add; - string avg_pool; - string avg_pool_grad; - string bias_add; - string bias_add_grad; - string concat; - string concatv2; - string conv2d; - string conv2d_grad_input; - string conv2d_grad_filter; - string fused_batch_norm; - string fused_batch_norm_grad; - string identity; - string lrn; - string lrn_grad; - string matmul; - string max_pool; - string max_pool_grad; - string maximum; - string mkl_conv2d; - string mkl_conv2d_grad_input; - string mkl_conv2d_grad_filter; - string mkl_conv2d_with_bias; - string mkl_conv2d_with_bias_backprop_bias; - string mul; - string relu; - string relu_grad; - string reshape; - string split; - string squared_difference; - string sub; - } ConstStringsInfo; - - private: - /// Maintain info about nodes to rewrite - std::vector<RewriteInfo> rinfo_; - - /// Maintain info about nodes to add workspace edge - std::vector<WorkSpaceInfo> wsinfo_; - - /// Maintain info about nodes to be merged - std::vector<MergeInfo> minfo_; - - /// Maintain info about nodes to rewrite - static std::vector<ContextInfo*> cinfo_; - - /// Maintain structure of constant strings - static ConstStringsInfo csinfo_; - - /// Context variables used in referencing rules - static ContextInfo biasaddgrad_matmul_context_; - static ContextInfo biasaddgrad_conv2dwithbias_context_; - - private: - // Is OpDef::ArgDef a list type? It could be N * T or list(type). - // Refer to opdef.proto for details of list type. - inline bool ArgIsList(const OpDef::ArgDef& arg) const { - return !arg.type_list_attr().empty() || !arg.number_attr().empty(); - } - - // Get length of a list in 'n' if 'arg' is of list type. Refer to - // description of ArgIsList for definition of list type. - inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) { - CHECK_EQ(ArgIsList(arg), true); - int N = 0; - const string attr_name = !arg.type_list_attr().empty() - ? arg.type_list_attr() - : arg.number_attr(); - if (!arg.type_list_attr().empty()) { - std::vector<DataType> value; - TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); - N = value.size(); - } else { - TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); - } - return N; - } - - // Can op represented by node 'n' run on DEVICE_CPU? - // Op can run on CPU with MKL if the runtime assigned device or the - // user requested device contains device CPU, or both are empty. - bool CanOpRunOnCPUDevice(const Node* n) { - bool result = true; - string reason; - - // Substring that should be checked for in device name for CPU device. - const char* const kCPUDeviceSubStr = "CPU"; - - // If Op has been specifically assigned to a non-CPU device, then No. - if (!n->assigned_device_name().empty() && - !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) { - result = false; - reason = "Op has been assigned a runtime device that is not CPU."; - } - - // If user has specifically assigned this op to a non-CPU device, then No. - if (!n->def().device().empty() && - !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) { - result = false; - reason = "User has assigned a device that is not CPU."; - } - - if (result == false) { - VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " - << n->type_string() << ", reason: " << reason; - } - - // Otherwise Yes. - return result; - } - - // Return a node that can be merged with input node 'n' - // - // @return pointer to the node if we can find such a - // node. Otherwise, it returns nullptr. 
- Node* CheckForNodeMerge(const Node* n) const; - - // Merge predecessor node with its successor. - // Currently, we merge Conv2D with BiasAdd only. - // - // Input nodes succ and pred may be deleted if the call to - // this function is successful. Attempt to use the pointers - // after the call to function may result in undefined behaviors. - // - // @input g - input graph, succ - successor node, pred - predecessor node - // @return Status::OK(), if merging is successful and supported. - // Returns appropriate Status error code otherwise. - // Graph is updated in case nodes are merged. Otherwise, it is - // not updated. - Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred); - - // Check if the node 'n' has any applicable rewrite rule - // We check for 2 scenarios for rewrite. - // - // @return RewriteInfo* for the applicable rewrite rule - const RewriteInfo* CheckForNodeRewrite(const Node* n) const; - - // Default rewrite rule to be used in scenario 1 for rewrite. - // @return - true (since we want to always rewrite) - static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) { - return true; - } - - // Check if we are performing pooling on depth or batch. If it is, then we - // do not rewrite MaxPool node to Mkl version. - // @return - true (if it is not a depth/batch wise pooling case); - // false otherwise. - static bool NonDepthBatchWisePoolRewrite(const Node* n, - const ContextInfo* c) { - CHECK_NOTNULL(n); - - string data_format_str; - TensorFormat data_format; - std::vector<int32> ksize, strides; - CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true); - CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true); - CHECK_EQ(FormatFromString(data_format_str, &data_format), true); - - // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (GetTensorDim(ksize, data_format, 'N') == 1 && - GetTensorDim(strides, data_format, 'N') == 1 && - GetTensorDim(ksize, data_format, 'C') == 1 && - GetTensorDim(strides, data_format, 'C') == 1) { - return true; - } - - return false; - } - - static bool AddNRewrite(const Node* n, const ContextInfo* c) { - CHECK_NOTNULL(n); - - int num; - CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true); - - // Condition that specifies non-batch-wise and non-depth-wise pooling. - if (num == 2) { - return true; - } - - return false; - } - // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node - // specified in contextinfo 'ci'. Function updates fwd_node to point - // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias. - // - // Association checks for one of the following graphs: - // - // Graph A: - // - // _ = Conv2DWithBias(F, I, _) - // .. - // _ = Conv2DBackpropFilter(F, _, G) - // _ = Conv2DBackpropInput(_, I, G) - // _ = BiasAddGrad(G) - // - // OR - // - // Graph B: - // - // _ = Conv2DWithBias(F, _, _) - // .. - // _ = Conv2DBackpropFilter(F, _, G) - // _ = BiasAddGrad(G) - // - // Here F, G, and I are graph nodes; _ represents graph nodes that we - // don't care here. - // - // @return - true (if BiasAddGrad is associated with Conv2DWithBias); - // false otherwise. - static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n, - const Node** fwd_node, - void* ci) { - CHECK_NOTNULL(n); - CHECK_NOTNULL(fwd_node); - CHECK_NOTNULL(ci); - *fwd_node = nullptr; - - CHECK_EQ(n->type_string(), csinfo_.bias_add_grad); - - // Get the only 1 input of BiasAddGrad. 
- CHECK_EQ(n->num_inputs(), 1); - const Node* bias_add_grad_inp = nullptr; - TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp)); - CHECK_NOTNULL(bias_add_grad_inp); - - // Check if this input also goes to BackpropFilter and BackpropInput - // as 3rd input. - bool found_backprop_input = false; - bool found_backprop_filter = false; - Node* backprop_filter_node = nullptr; - Node* backprop_input_node = nullptr; - - for (const Edge* e : bias_add_grad_inp->out_edges()) { - Node* third_input = nullptr; - if (e->dst()->type_string() == csinfo_.conv2d_grad_input || - e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) { - // Third input (index 2) of BackpropInput - TF_CHECK_OK(e->dst()->input_node(2, &third_input)); - // Third input (index 2) of BackpropInput must be same as the input - // of BiasAddGrad. - if (third_input == bias_add_grad_inp) { - found_backprop_input = true; - backprop_input_node = e->dst(); - } - } - - if (e->dst()->type_string() == csinfo_.conv2d_grad_filter || - e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) { - // Third input (index 2) of BackpropFilter - TF_CHECK_OK(e->dst()->input_node(2, &third_input)); - // Third input (index 2) of BackpropFilter must be same as the input - // of BiasAddGrad. - if (third_input == bias_add_grad_inp) { - found_backprop_filter = true; - backprop_filter_node = e->dst(); - } - } - - // If we found both the nodes, then we can stop the search. - if (found_backprop_input && found_backprop_filter) { - break; - } - } - - // If BackpropFilter node is not found, then this is not - // Conv2DWithBias context. For 2nd graph in the example above, only - // BackpropFilter would be present. - if (!found_backprop_filter) { - return false; - } - - // Otherwise, we found the nodes. - CHECK_NOTNULL(backprop_filter_node); - if (found_backprop_input) { - CHECK_NOTNULL(backprop_input_node); - } - - // Now that we confirmed that this is Conv2DWithBias context, we need to - // get access to the forward node (Conv2DWithBias). 2nd input of - // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st - // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter - // (This comes from definition of gradient computation for Conv2D). - if (found_backprop_input) { - // Graph A in the example. - Node* second_inp_of_input = nullptr; - Node* first_inp_of_filter = nullptr; - TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input)); - TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); - CHECK_NOTNULL(second_inp_of_input); - CHECK_NOTNULL(first_inp_of_filter); - - // Now we need to find out Conv2DWithBias node from these input nodes. - // Conv2DWithBias node is the node that accepts both the nodes - // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots. - for (const Edge* fe : first_inp_of_filter->out_edges()) { - if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - fe->dst_input() == 0) { - for (const Edge* ie : second_inp_of_input->out_edges()) { - if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - ie->dst_input() == 1 && fe->dst() == ie->dst()) { - VLOG(1) << "MklLayoutRewritePass: found " - << fe->dst()->DebugString() - << " as the forward node for matching context, backward" - << " node is: " << n->DebugString(); - *fwd_node = fe->dst(); - return true; - } - } - } - } - } else { - // We did not find BackpropInput, so we work with BackpropFilter only. - // Graph B in the example. 
- Node* first_inp_of_filter = nullptr; - TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); - CHECK_NOTNULL(first_inp_of_filter); - - // Now we need to find out Conv2DWithBias node from first input of - // BackpropFIlter. Conv2DWithBias node is the node that accepts - // first_inp_of_filter in 1st input slot. - for (const Edge* fe : first_inp_of_filter->out_edges()) { - if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - fe->dst_input() == 0) { - VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString() - << " as the forward node for matching context, backward" - << " node is: " << n->DebugString(); - *fwd_node = fe->dst(); - return true; - } - } - } - - return false; - } - - // Is BiasAddGrad node in 'n' is associated with MatMul node - // specified in contextinfo 'ci'. Function does not update fwd_node. - // - // @return - true (if BiasAddGrad is associated with MatMul); - // false otherwise. - static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node, - void* ci) { - return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); - } - - // Rewrite rule that uses context-information for matching, - // used in scenario 2. - // - // @input - Node 'n' for which to search for matching context - // @input - The context 'c' under which to rewrite - // @return - true if we can rewrite node under context 'c'; - // false otherwise. - static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); - - // Helper function that searches the matching contextinfo for the node. - // - // @input n - Node (gradient op) whose contextinfo is to be searched, - // fwd_node - pointer to node from the forward pass that this node - // belongs to. fwd_node cannot be NULL. - // @return Matching contextinfo in case a match is found; null otherwise. - // Also updates *fwd_node with pointer to forward node that this - // context matches. - static const ContextInfo* SearchMatchingContext(const Node* n, - const Node** fwd_node); - - // Rewrites input node to a new node specified by its matching rewrite info. - // - // Method first searches matching rewrite info for input node and then - // uses that info to rewrite. - // - // Input node may be deleted in case of rewrite. Attempt to use the node - // after the call can result in undefined behaviors. - // - // @input g - input graph, n - Node to be rewritten, - // ri - matching rewriteinfo - // @return Status::OK(), if the input node is rewritten; - // Returns appropriate Status error code otherwise. - // Graph is updated in case the input node is rewritten. - // Otherwise, it is not updated. - Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); - - // Get nodes that will feed a list of TF tensors to the new - // node that we are constructing. 
- // - // @input g - input graph, - // @input inputs - inputs to old node that we are using for constructing - // new inputs, - // @input input_idx - the index in the 'inputs' vector pointing to the - // current input that we have processed so far - // @output input_idx - index will be incremented by the number of nodes - // from 'inputs' that are processed - // @input list_length - The expected length of list of TF tensors - // @output output_nodes - the list of new nodes creating TF tensors - // - // @return None - void GetNodesProducingTFTensorList( - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, - int* input_idx, int list_length, - std::vector<NodeBuilder::NodeOut>* output_nodes); - - // Get nodes that will feed a list of Mkl tensors to the new - // node that we are constructing. - // - // @input g - input graph, - // @input orig_node - Original node that we are rewriting - // @input inputs - inputs to old node that we are using for constructing - // new inputs, - // @input input_idx - the index in the 'inputs' vector pointing to the - // current input that we have processed so far - // @output input_idx - index will be incremented by the number of nodes - // from 'inputs' that are processed - // @input list_length - The expected length of list of Mkl tensors - // @output output_nodes - the list of new nodes creating Mkl tensors - // - // @return None - void GetNodesProducingMklTensorList( - std::unique_ptr<Graph>* g, Node* orig_node, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, - int* input_idx, int list_length, - std::vector<NodeBuilder::NodeOut>* output_nodes); - - // Get a node that will feed an Mkl tensor to the new - // node that we are constructing. The output node could be (1) 'n' - // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor - // if 'n' is not an Mkl layer. - // - // @input g - input graph, - // @input orig_node - Original node that we are rewriting, - // @input n - Node based on which we are creating Mkl node, - // @input n_output_slot - the output slot of node 'n' - // which is feeding to the node that we are constructing - // @output mkl_node - the new node that will feed Mkl tensor - // @output mkl_node_output_slot - the slot number of mkl_node that - // will feed the tensor - // @return None - void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node, - Node* n, int n_output_slot, Node** mkl_node, - int* mkl_node_output_slot); - - // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' - // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are - // set up in contiguous fashion. 'workspace_tensors' carry graph nodes - // producing workspace edges if 'are_workspace_tensors_available' is true. - // Otherwise, 'workspace_tensors' is empty vector. - // - // For details, refer to 'Ordering of inputs after rewriting' section in the - // documentation above. - // - // Returns Status::OK() if setting up inputs is successful, otherwise - // returns appropriate status code. - int SetUpContiguousInputs( - std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, - NodeBuilder* nb, Node* old_node, - std::vector<NodeBuilder::NodeOut>* workspace_tensors, - bool are_workspace_tensors_available); - - // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' - // in graph 'g'. Original node is input in 'orig_node'. 
- // - // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' - // section in the documentation above. - // - // Returns Status::OK() if setting up inputs is successful, otherwise - // returns appropriate status code. - Status SetUpInputs(std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, - NodeBuilder* nb, Node* orig_node); - - // Add workspace edge on the input or output side of Node 'orig_node' by using - // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate - // adding workspace edge then do not add it. Workspace Tensorflow and Mkl - // tensors, if they need to be added, will be set into these tensors. - // If we set workspace tensors, then are_ws_tensors_added should be true. - void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node, - NodeBuilder* nb, - std::vector<NodeBuilder::NodeOut>* ws_tensors, - bool* are_ws_tensors_added); - - // Functions specific to operators to copy attributes - // We need operator-specific function to copy attributes because the framework - // does not provide any generic function for it. - // NOTE: names are alphabetically sorted. - static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); - - // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, - // using node for original node 'orig_node' and return it in '*out'. - // TODO(nhasabni) We should move this to mkl_util.h - void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, - Node* orig_node); - void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, - Node* orig_node); -}; - -MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; -MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; -MklLayoutRewritePass::ContextInfo - MklLayoutRewritePass::biasaddgrad_matmul_context_; -std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; - -// We register Mkl rewrite pass for phase 1 in post partitioning group. -// We register it here so that we get a complete picture of all users of Mkl -// nodes. Do not change the ordering of the Mkl passes. 
-const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = - OptimizationPassRegistry::POST_PARTITIONING; -#ifdef ENABLE_MKL -REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); -#endif // ENABLE_MKL - -////////////////////////////////////////////////////////////////////////// -// Helper functions for creating new node -////////////////////////////////////////////////////////////////////////// - -static void FillInputs(const Node* n, - gtl::InlinedVector<Node*, 4>* control_edges, - gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { - control_edges->clear(); - for (const Edge* e : n->in_edges()) { - if (e->IsControlEdge()) { - control_edges->push_back(e->src()); - } else { - (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); - } - } - std::sort(control_edges->begin(), control_edges->end()); - if (n->op_def().is_commutative()) { - // For commutative inputs, we sort the input by the input Node* - // to get a canonical ordering (so that add(a,b) and add(b, a) will - // hash to the same value if is_commutative is true for 'add'). - std::sort(in->begin(), in->end()); - } -} - -void MklLayoutRewritePass::GetNodesProducingTFTensorList( - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, - int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { - CHECK_LT(*input_idx, inputs.size()); - CHECK_GT(list_length, 0); - CHECK_NOTNULL(output_nodes); - output_nodes->reserve(list_length); - - while (list_length != 0) { - CHECK_GT(list_length, 0); - CHECK_LT(*input_idx, inputs.size()); - Node* n = inputs[*input_idx].first; - int slot = inputs[*input_idx].second; - // If input node 'n' is just producing a single tensor at - // output slot 'slot' then we just add that single node. - output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); - (*input_idx)++; - list_length--; - } -} - -// TODO(nhasabni) We should move this to mkl_util.h. -void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, - Node** out, Node* orig_node) { - // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent - // dummy Mkl tensor. 8 = 2*size_t. - const DataType dt = DataTypeToEnum<uint8>::v(); - TensorProto proto; - proto.set_dtype(dt); - uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; - proto.set_tensor_content(string(reinterpret_cast<const char*>(zero), 8)); - TensorShape dummy_shape({8}); - dummy_shape.AsProto(proto.mutable_tensor_shape()); - TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // the same device as the - // device of the original - // node. - .Finalize(&**g, out)); - CHECK_NOTNULL(*out); // Make sure we got a valid object before using it - - // If number of inputs to the original node is > 0, then we add - // control dependency between 1st input (index 0) of the original node and - // the dummy Mkl node. This is needed because control-flow ops such as Enter, - // Merge, etc, require frame_name of the dummy Mkl node to be same as the - // rewritten node. Adding control edge between 1st input of the original node - // and the dummy Mkl node ensures that the dummy node is in the same frame - // as the original node. Choosing 1st input is not necessary - any input of - // the original node is fine because all the inputs of a node are always in - // the same frame. 
- if (orig_node->num_inputs() > 0) { - Node* orig_input0 = nullptr; - TF_CHECK_OK( - orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); - CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); - } - - (*out)->set_assigned_device_name(orig_node->assigned_device_name()); -} - -void MklLayoutRewritePass::GetNodesProducingMklTensorList( - std::unique_ptr<Graph>* g, Node* orig_node, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, - int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { - CHECK_LT(*input_idx, inputs.size()); - CHECK_GT(list_length, 0); - CHECK_NOTNULL(output_nodes); - output_nodes->reserve(list_length); - - while (list_length != 0) { - CHECK_GT(list_length, 0); - CHECK_LT(*input_idx, inputs.size()); - Node* n = inputs[*input_idx].first; - int slot = inputs[*input_idx].second; - // If 'n' is producing a single tensor, then create a single Mkl tensor - // node. - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, - &mkl_node_output_slot); - output_nodes->push_back( - NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); - (*input_idx)++; - list_length--; - } -} - -// Get an input node that will feed Mkl tensor to the new -// node that we are constructing. An input node could be (1) 'n' -// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor -// if 'n' is not an Mkl layer. -void MklLayoutRewritePass::GetNodeProducingMklTensor( - std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot, - Node** mkl_node, int* mkl_node_output_slot) { - CHECK_NOTNULL(n); - CHECK_NOTNULL(mkl_node); - CHECK_NOTNULL(mkl_node_output_slot); - - // If this is an MKL op, then it will create extra output for MKL layout. - DataType T; - if (GetNodeAttr(n->def(), "T", &T).ok() && - mkl_op_registry::IsMklOp(n->type_string(), T)) { - // If this is an MKL op, then it will generate an edge that will receive - // Mkl tensor from a node. - // output slot number for Mkl tensor would be N+slot number of TensorFlow - // tensor, where N is total number of TensorFlow tensors. - *mkl_node = n; - *mkl_node_output_slot = - GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); - } else { - // If we have not visited the node and rewritten it, then we need - // to create a dummy node that will feed a dummy Mkl tensor to this node. - // DummyMklTensor node has no input and generates only 1 output - // (dummy Mkl tensor) as output slot number 0. - GetDummyMklTensorNode(g, mkl_node, orig_node); - CHECK_NOTNULL(*mkl_node); - *mkl_node_output_slot = 0; - } -} - -int MklLayoutRewritePass::SetUpContiguousInputs( - std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, - NodeBuilder* nb, Node* old_node, - std::vector<NodeBuilder::NodeOut>* workspace_tensors, - bool are_workspace_tensors_available) { - CHECK_NOTNULL(workspace_tensors); - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - - // TODO(nhasabni): Temporary solution to connect filter input of - // BackpropInput with the converted filter from Conv2D. - bool do_connect_conv2d_backprop_input_filter = false; - Node* conv2d_node = nullptr; - // Filter node is 2nd input (slot index 1) of Conv2D. - int kConv2DFilterInputSlotIdx = 1; - int kConv2DBackpropInputFilterInputSlotIdx = 1; - int kConv2DFilterOutputSlotIdx = 1; - if (old_node->type_string() == csinfo_.conv2d_grad_input) { - // We need to find Conv2D node from Conv2DBackpropInput. 
- // For that let's first find filter node that is 2nd input (slot 1) - // of BackpropInput. - Node* filter_node = nullptr; - TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, - &filter_node)); - CHECK_NOTNULL(filter_node); - - // Now check which nodes receive from filter_node. Filter feeds as - // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias. - for (const Edge* e : filter_node->out_edges()) { - if (e->dst()->type_string() == csinfo_.mkl_conv2d && - e->dst_input() == kConv2DFilterInputSlotIdx - /* filter is 2nd input of Conv2D and _MklConv2D. */) { - if (conv2d_node != nullptr) { - VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" - << " feeding multiple Conv2D nodes: " - << filter_node->DebugString(); - // We will not connect filter input of Conv2DBackpropInput - // to be safe here. - do_connect_conv2d_backprop_input_filter = false; - break; - } else { - conv2d_node = e->dst(); - do_connect_conv2d_backprop_input_filter = true; - } - } - } - } - - // Number of input slots to original op - // Input slots are represented by .Input() calls in REGISTER_OP. - int old_node_input_slots = old_node->op_def().input_arg_size(); - // Actual number of inputs can be greater than or equal to number - // of Input slots because inputs of type list could be unfolded. - CHECK_GE(old_node_inputs.size(), old_node_input_slots); - int nn_slot_idx = 0; // slot index for inputs of new node - - // Let's copy all inputs (TF tensors) of original node to new node. - int iidx = 0; - for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { - // An input slot could be a single tensor or a list. We need - // to handle this case accordingly. - CHECK_LT(iidx, old_node_inputs.size()); - const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); - if (ArgIsList(arg)) { - std::vector<NodeBuilder::NodeOut> new_node_inputs; - int N = GetTensorListLength(arg, old_node); - GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, - &new_node_inputs); - nb->Input(new_node_inputs); - nn_slot_idx++; - } else { - // Special case for connecting filter input of Conv2DBackpropInput - if (do_connect_conv2d_backprop_input_filter && - iidx == kConv2DBackpropInputFilterInputSlotIdx) { - nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); - } else { - nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); - } - iidx++; - nn_slot_idx++; - } - } - - // If workspace tensors are available for this op and we are using - // contiguous ordering then we need to add Tensorflow tensor for - // workspace here because Tensorflow tensor for workspace is the - // last tensor in the list of Tensorflow tensors. - if (are_workspace_tensors_available) { - CHECK_EQ(workspace_tensors->size(), 2); - // Tensorflow tensor - nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); - nn_slot_idx++; - } - - // Let's now setup all Mkl inputs to new node. - // Number of Mkl inputs must be same as number of TF inputs. - iidx = 0; - for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { - // An input slot could be a single tensor or a list. We need - // to handle this case accordingly. 
- CHECK_LT(iidx, old_node_inputs.size()); - const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); - if (ArgIsList(arg)) { - std::vector<NodeBuilder::NodeOut> new_node_inputs; - int N = GetTensorListLength(arg, old_node); - GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N, - &new_node_inputs); - nb->Input(new_node_inputs); - nn_slot_idx++; - } else { - Node* mkl_node = nullptr; - int mkl_node_output_slot = 0; - // Special case for connecting filter input of Conv2DBackpropInput - if (do_connect_conv2d_backprop_input_filter && - iidx == kConv2DBackpropInputFilterInputSlotIdx) { - GetNodeProducingMklTensor(g, old_node, conv2d_node, - kConv2DFilterOutputSlotIdx, &mkl_node, - &mkl_node_output_slot); - } else { - GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, - old_node_inputs[iidx].second, &mkl_node, - &mkl_node_output_slot); - } - nb->Input(mkl_node, mkl_node_output_slot); - iidx++; - nn_slot_idx++; - } - } - - // If workspace tensors are available for this op and we are using - // contiguous ordering then we need to add Mkl tensor for - // workspace here because Mkl tensor for workspace is the - // last tensor in the list of Mkl tensors. - if (are_workspace_tensors_available) { - CHECK_EQ(workspace_tensors->size(), 2); - // Mkl tensor - nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); - nn_slot_idx++; - } - - return nn_slot_idx; -} - -Status MklLayoutRewritePass::SetUpInputs( - std::unique_ptr<Graph>* g, - const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, - NodeBuilder* nb, Node* old_node) { - // Let's check if we need to add workspace tensors for this node. - // We add workspace edge only for MaxPool, LRN and BatchNorm. - std::vector<NodeBuilder::NodeOut> workspace_tensors; - bool are_workspace_tensors_available = false; - AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, - &are_workspace_tensors_available); - - int new_node_input_slots = 0; - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - // TODO(nhasabni): implement this function just for same of completion. - // We do not use interleaved ordering right now. - return Status( - error::Code::UNIMPLEMENTED, - "Interleaved ordering of tensors is currently not supported."); - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - new_node_input_slots = SetUpContiguousInputs( - g, old_node_inputs, nb, old_node, &workspace_tensors, - are_workspace_tensors_available); - } - - // Sanity check - int old_node_input_slots = old_node->op_def().input_arg_size(); - if (!are_workspace_tensors_available) { - // If we are not adding workspace tensors for this op, then the total - // number of input slots to the new node _must_ be 2 times the number - // of input slots to the original node: N original Tensorflow tensors and - // N for Mkl tensors corresponding to each Tensorflow tensors. - CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); - } else { - // If we are adding workspace tensors for this op, then the total - // The total number of input slots to new node _must_ be 2 times the number - // of input slots to the original node: N original Tensorflow tensors and - // N for Mkl tensors corresponding to each Tensorflow tensors plus 2 - // (for workspace Tensorflow tensor and workspace Mkl tensor). 
- CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); - } - - return Status::OK(); -} - -////////////////////////////////////////////////////////////////////////// -// Helper functions related to workspace pass -////////////////////////////////////////////////////////////////////////// - -// TODO(nhasabni) We should move this to mkl_util.h. -void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( - std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { - // We use a tensor of shape {1} and value 0 to represent - // dummy float tensor. We need this as a dummy workspace tensor. - // Workspace tensor has type float. - const DataType dt = DataTypeToEnum<float>::v(); - TensorProto proto; - proto.set_dtype(dt); - float zero[1] = {0}; - proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4)); - TensorShape dummy_shape({1}); - dummy_shape.AsProto(proto.mutable_tensor_shape()); - TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const") - .Attr("value", proto) - .Attr("dtype", dt) - .Device(orig_node->def().device()) // We place this node on - // same the device as the - // device of the original - // node. - .Finalize(&**g, out)); - CHECK_NOTNULL(*out); // Make sure we got a valid object before using it - - // If number of inputs to the original node is > 0, then we add - // control dependency between 1st input (index 0) of the original node and - // the dummy Mkl node. This is needed because control-flow ops such as Enter, - // Merge, etc, require frame_name of the dummy Mkl node to be same as the - // rewritten node. Adding control edge between 1st input of the original node - // and the dummy Mkl node ensures that the dummy node is in the same frame - // as the original node. Choosing 1st input is not necessary - any input of - // the original node is fine because all the inputs of a node are always in - // the same frame. - if (orig_node->num_inputs() > 0) { - Node* orig_input0 = nullptr; - TF_CHECK_OK( - orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); - CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); - } - - (*out)->set_assigned_device_name(orig_node->assigned_device_name()); -} - -void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( - std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb, - std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { - bool workspace_edge_added = false; // Default initializer - CHECK_NOTNULL(are_ws_tensors_added); - *are_ws_tensors_added = false; // Default initializer - - DataType T; - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - for (auto ws : wsinfo_) { - if (orig_node->type_string() == ws.fwd_op && - mkl_op_registry::IsMklOp( - mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { - // If this op is a fwd op, then we need to check if there is an - // edge from this node's fwd_slot to bwdop's bwd_slot. If there is - // an edge, then we just add an attribute on this node for setting - // workspace_passed to true. We don't add actual workspace edge - // in this node. Actual workspace edge gets added in the backward - // op for this node. - for (const Edge* e : orig_node->out_edges()) { - if (e->src_output() == ws.fwd_slot && - e->dst()->type_string() == ws.bwd_op && - e->dst_input() == ws.bwd_slot) { - nb->Attr("workspace_enabled", true); - VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " - << orig_node->type_string(); - workspace_edge_added = true; - // We found the edge that we were looking for, so break. 
- break; - } - } - - if (!workspace_edge_added) { - // If we are here, then we did not find backward operator for this - // node. - nb->Attr("workspace_enabled", false); - } - } else if (orig_node->type_string() == ws.bwd_op && - mkl_op_registry::IsMklOp( - mkl_op_registry::GetMklOpName(orig_node->type_string()), - T)) { - // If this op is a bwd op, then we need to add workspace edge and - // it's Mkl tensor edge between its corresponding fwd op and this - // op. Corresponding fwd op is specified in 'fwd_op' field of - // workspace info. fwd_slot and bwd_slot in workspace info specify - // an edge between which slots connect forward and backward op. - // Once all these criteria match, we add a workspace edge between - // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is - // determined by interleaved/contiguous ordering. Function - // DataIndexToMetaDataIndex tells us the location of Mkl tensor - // from the location of the Tensorflow tensor. - for (const Edge* e : orig_node->in_edges()) { - if (e->src_output() == ws.fwd_slot && - // We would have rewritten the forward op, so we need to use - // GetMklOpName call to get its Mkl name. - e->src()->type_string() == - mkl_op_registry::GetMklOpName(ws.fwd_op) && - e->dst_input() == ws.bwd_slot) { - nb->Attr("workspace_enabled", true); - CHECK_NOTNULL(ws_tensors); - // Add workspace edge between fwd op and bwd op. - ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); - // Add Mkl tensor edge for workspace edge between fwd op and bwd op. - ws_tensors->push_back(NodeBuilder::NodeOut( - e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, - e->src()->num_outputs()))); - *are_ws_tensors_added = true; - // In terms of input ordering, we add these calls to add Input - // here because workspace edge (and its Mkl tensor) is the last - // edge in the fwdop and bwdop. So all inputs before workspace - // tensor have been added by SetUpInputs function. - VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " - << orig_node->type_string(); - workspace_edge_added = true; - // We found the edge that we were looking for, so break. - break; - } - } - - // If we are here means we did not find fwd op that feeds to this - // bwd op. So in this case, we need to generate dummy tensors for - // workspace input and Mkl tensor for workspace, and set - // workspace_enabled to false. - if (!workspace_edge_added) { - nb->Attr("workspace_enabled", false); - Node* dmt_ws = nullptr; // Dummy tensor for workspace - Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace - GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); - GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); - CHECK_NOTNULL(dmt_ws); - CHECK_NOTNULL(dmt_mkl_ws); - CHECK_NOTNULL(ws_tensors); - // We add dummy tensor as workspace tensor. - ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); - // We add dummy tensor as Mkl tensor for workspace tensor. - ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); - *are_ws_tensors_added = true; - VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " - << orig_node->type_string(); - } - } else { - // If this node does not match any workspace info, then we do not - // do anything special for workspace propagation for it. 
- } - } -} - -////////////////////////////////////////////////////////////////////////// -// Op-specific functions to copy attributes from old node to new node -////////////////////////////////////////////////////////////////////////// - -void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - string data_format; - string padding; - std::vector<int32> strides; - bool use_cudnn_on_gpu; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - TF_CHECK_OK( - GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("strides", strides); - nb->Attr("padding", padding); - nb->Attr("data_format", data_format); - nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); -} - -void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - int N; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("N", N); -} - -void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - string data_format; - std::vector<int32> strides; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("strides", strides); - nb->Attr("data_format", data_format); -} - -void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - // Add attributes to new node. - nb->Attr("T", T); -} - -void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - int depth_radius; - float bias; - float alpha; - float beta; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("depth_radius", depth_radius); - nb->Attr("bias", bias); - nb->Attr("alpha", alpha); - nb->Attr("beta", beta); -} - -void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - string data_format; - string padding; - std::vector<int32> ksize, strides; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - - // Add attributes to new node. 
- nb->Attr("T", T); - nb->Attr("ksize", ksize); - nb->Attr("strides", strides); - nb->Attr("padding", padding); - nb->Attr("data_format", data_format); -} - -void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - - // Add attributes to new node. - nb->Attr("T", T); -} - -void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - DataType Tshape; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape)); - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("Tshape", Tshape); -} - -void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - string data_format; - int num_split; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("num_split", num_split); - nb->Attr("data_format", data_format); -} - -void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - int N; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("N", N); -} - -void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - int N; - DataType tidx; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("N", N); - nb->Attr("Tidx", tidx); -} - -void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node, - NodeBuilder* nb) { - DataType T; - float epsilon; - string data_format; - bool is_training; - - // Get all attributes from old node. - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training)); - - // Add attributes to new node. - nb->Attr("T", T); - nb->Attr("epsilon", epsilon); - nb->Attr("data_format", data_format); - nb->Attr("is_training", is_training); -} - -////////////////////////////////////////////////////////////////////////// -// Helper functions related to node merge pass -////////////////////////////////////////////////////////////////////////// - -Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { - // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite - // once we support BiasAddGrad as Mkl layer. - - // Search for all matching mergeinfo. - // We allow more than one match for extensibility. 
- std::vector<const MergeInfo*> matching_mi; - for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { - if (a->type_string() == mi->succ) { - matching_mi.push_back(&*mi); - } - } - - for (const MergeInfo* mi : matching_mi) { - const int N_in = a->num_inputs(); - if (mi->op >= N_in) { - continue; - } - - // Get the control edges and input of node - gtl::InlinedVector<Node*, 4> a_control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); - FillInputs(a, &a_control_edges, &a_in); - - // Get operand op of the operator - Node* b = nullptr; - b = a_in[mi->op].first; - if (b == nullptr || (b->type_string() != mi->pred)) { - // NOTE: Should the first check be assert? - continue; - } - - const int B_in = b->num_inputs(); - gtl::InlinedVector<Node*, 4> b_control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); - FillInputs(b, &b_control_edges, &b_in); - - // Shouldn't merge if a and b have different control edges. - if (a_control_edges != b_control_edges) { - continue; - } else { - // We found a match. - return b; - } - } - - return nullptr; -} - -Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ, - Node* pred) { - CHECK_NOTNULL(succ); - CHECK_NOTNULL(pred); - - if (succ->type_string() == csinfo_.bias_add && - pred->type_string() == csinfo_.mkl_conv2d) { - // 1. Get all attributes from input nodes. - DataType T_pred, T_succ; - string padding; - std::vector<int32> strides; - string data_format_pred, data_format_succ; - bool use_cudnn_on_gnu; - TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred)); - TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ)); - TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding)); - TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides)); - TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred)); - TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ)); - TF_CHECK_OK( - GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu)); - // We check to ensure that data formats of both succ and pred are same. - // We expect them to be same, so we can enforce this as assert. - // But assert can be too strict, so we enforce this as a check. - // If the check fails, then we do not merge two nodes. - // We also do same check for devices. - if (data_format_pred != data_format_succ || T_pred != T_succ || - pred->assigned_device_name() != succ->assigned_device_name() || - pred->def().device() != succ->def().device()) { - return Status(error::Code::INVALID_ARGUMENT, - "data_format or T attribute or devices of Conv2D and " - "BiasAdd do not match. Will skip node merge optimization"); - } - - const int succ_num = succ->num_inputs(); - gtl::InlinedVector<Node*, 4> succ_control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); - FillInputs(succ, &succ_control_edges, &succ_in); - - const int pred_num = pred->num_inputs(); - gtl::InlinedVector<Node*, 4> pred_control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); - FillInputs(pred, &pred_control_edges, &pred_in); - - // We need to ensure that there is only 1 edge between Conv2D and AddBias. - // Otherwise, merging is semantically incorrect. - if (pred->out_edges().size() != 1) { - return Status(error::Code::INVALID_ARGUMENT, - "Conv2D has multiple outputs." - "Will skip node merge optimization"); - } - - for (const Edge* e : pred->out_edges()) { - if (e->dst() != succ) { - return Status(error::Code::INVALID_ARGUMENT, - "Conv2D does not feed to BiasAdd." 
- "Will skip node merge optimization"); - } - } - - // 2. Get inputs from both the nodes. - // Find the 2 inputs from the conv and the bias from the add Bias. - // Get operand 0, 1 of conv2D and their Mkl tensors. - CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs. - // Get operand 1 of add_bias - // BiasAdd must have 2 inputs: Conv, bias - CHECK_EQ(succ->in_edges().size(), 2); - Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3 - int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0. - GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node - // as BiasAdd does not have Mkl tensor as input. - CHECK_NOTNULL(oper3_mkl); - - // We will use the node name of BiasAdd as the name of new node - // Build new node. We use same name as original node, but change the op - // name. - NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias); - if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { - nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D - // pred_in[1] will be Mkl tensor for In1 if we follow interleaved - // ordering, and it will be 2nd Tensorflow tensor for Conv2D if - // we follow contiguous ordering. - nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1 - nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D - nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 - nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd - nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd - } else { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D - // pred_in[1] will be Mkl tensor for In1 if we follow interleaved - // ordering, and it will be 2nd Tensorflow tensor for Conv2D if - // we follow contiguous ordering. - nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D - nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd - nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D - nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D - nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd - } - - // Copy attributes from Conv2D to Conv2DWithBias. - CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); - - // Copy the device assigned to old node to new node. - nb.Device(succ->def().device()); - - // Create node. - Node* new_node; - TF_CHECK_OK(nb.Finalize(&**g, &new_node)); - CHECK_NOTNULL(new_node); - - // Set the Mkl layer label for this op. - new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel); - - // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' - // node are already copied in BuildNode. We handle control edges now. - for (const Edge* e : pred->in_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); - } - } - for (const Edge* e : succ->in_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); - } - } - - // Incoming edges are fixed, we will fix the outgoing edges now. - // First, we will fix outgoing control edges from 'pred' node. - // We don't need to handle outgoing data edges from 'pred' node - // because pred has only 1 output going to succ node (we enforced - // this check for merge already). 
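Read off the contiguous-ordering branch above, the merged _MklConv2DWithBias node ends up with its three data inputs first and their Mkl metadata inputs second, which is also the layout the merge tests below expect (A->E, B->E:1, D->E:2, M->E:3, N->E:4, DMT/_0->E:5). A schematic of the slot assignment, with illustrative slot names that are not from the source:

    // Slot layout produced by the contiguous-ordering branch above
    // (names are illustrative only).
    enum MklConv2DWithBiasInputSlot {
      kInput = 0,      // pred_in[0]: Conv2D input
      kFilter = 1,     // pred_in[1]: Conv2D filter
      kBias = 2,       // succ_in[1]: BiasAdd bias
      kInputMkl = 3,   // pred_in[2]: Mkl metadata for the input
      kFilterMkl = 4,  // pred_in[3]: Mkl metadata for the filter
      kBiasMkl = 5,    // oper3_mkl: dummy Mkl metadata for the bias
    };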
- for (const Edge* e : pred->out_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); - } - } - - // Second, we will fix outgoing control and data edges from 'succ' node. - for (const Edge* e : succ->out_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); - } else { - CHECK_NOTNULL( - (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())); - } - } - - // Copy device assigned to old node to new node. - // It's ok to use pred or succ as we have enforced a check that - // both have same device assigned. - new_node->set_assigned_device_name(pred->assigned_device_name()); - - VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() - << ", and node: " << succ->DebugString() - << ", into node:" << new_node->DebugString(); - - (*g)->RemoveNode(succ); - (*g)->RemoveNode(pred); - - return Status::OK(); - } - - return Status(error::Code::UNIMPLEMENTED, - "Unimplemented case for node merge optimization."); -} - -////////////////////////////////////////////////////////////////////////// -// Helper functions for node rewrite -////////////////////////////////////////////////////////////////////////// - -Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, - Node* orig_node, - const RewriteInfo* ri) { - CHECK_NOTNULL(ri); - CHECK_NOTNULL(orig_node); - - VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); - - // Check if this is scenario 2 (context-based rewrite). - // Get the matching ContextInfo if it is. - const Node* fwd_node = nullptr; - const ContextInfo* ci = nullptr; - bool is_context_based_rewrite = false; - if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) { - is_context_based_rewrite = true; - - // Sanity checks for context-based rewrite (if any) - if (orig_node->type_string() == csinfo_.bias_add_grad && - ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { - CHECK_NOTNULL(fwd_node); - DataType orig_T, ctx_T; - string orig_data_format, ctx_data_format; - TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); - TF_CHECK_OK( - GetNodeAttr(orig_node->def(), "data_format", &orig_data_format)); - TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T)); - TF_CHECK_OK( - GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format)); - - if (orig_data_format != ctx_data_format || orig_T != ctx_T || - orig_node->assigned_device_name() != - fwd_node->assigned_device_name() || - orig_node->def().device() != fwd_node->def().device()) { - return Status( - error::Code::INVALID_ARGUMENT, - "data_format or T attribute or devices of BiasAddGrad and " - "Conv2D do not match. Will skip node rewrite optimization"); - } - } else if (orig_node->type_string() == csinfo_.bias_add_grad && - ri->new_name == csinfo_.matmul) { - // When BiasAddGrad has MatMul in context, we do not do any rewrite - // and leave BiasAddGrad as it is. But we check for this condition - // when we check for node rewrite rule. So we should not even come - // here for MatMul. So we will fail now. - return Status( - error::Code::INVALID_ARGUMENT, - "No rewrite is required for BiasAddGrad for MatMul context."); - } - } - - // Get all inputs. 
- int num_inputs = orig_node->in_edges().size(); - - // Drop count for control edges from inputs - for (const Edge* e : orig_node->in_edges()) { - if (e->IsControlEdge()) { - num_inputs--; - } - } - - gtl::InlinedVector<Node*, 4> control_edges; - gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs); - FillInputs(orig_node, &control_edges, &inputs); - - // Build new node. We use same name as original node, but change the op name. - NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); - // Copy user-specified device assigned to original node to new node. - nb.Device(orig_node->def().device()); - // Set up new inputs to the rewritten node. - Status s = SetUpInputs(g, inputs, &nb, orig_node); - if (s != Status::OK()) { - return s; - } - - // Copy attributes from original node to new node (for scenario 1). - // For context-based rewrite, we use context to copy the attributes. - if (is_context_based_rewrite) { - if (orig_node->type_string() == csinfo_.bias_add_grad && - ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { - CHECK_NOTNULL(fwd_node); - ri->copy_attrs(fwd_node, &nb); - } else { - return Status(error::Code::UNIMPLEMENTED, - "Unimplemented case for node rewrite optimization."); - } - } else { - ri->copy_attrs(const_cast<const Node*>(orig_node), &nb); - } - // Set the Mkl layer label for this op. - nb.Attr("_kernel", mkl_op_registry::kMklOpLabel); - - // Finalize graph and get new node. - Node* new_node = nullptr; - TF_CHECK_OK(nb.Finalize(&**g, &new_node)); - CHECK_NOTNULL(new_node); - - // Incoming data edges from 'orig_node' node to new 'new_node' node are - // already copied in BuildNode. We need to handle control edges now. - for (const Edge* e : orig_node->in_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); - } - } - - // Copy outgoing edges from 'orig_node' node to new - // 'new_node' node, since the output also follows same ordering among - // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow - // tensors appropriately. Specifically, nth output of the original node - // will become 2*nth output of the Mkl node for the interleaved ordering - // of the tensors. For the contiguous ordering of the tensors, it will be n. - // GetTensorDataIndex provides this mapping function. - for (const Edge* e : orig_node->out_edges()) { - if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); - } else { - CHECK_NOTNULL((*g)->AddEdge( - new_node, - GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), - e->dst(), e->dst_input())); - } - } - - // Copy the runtime device assigned from original code to new node. - new_node->set_assigned_device_name(orig_node->assigned_device_name()); - - // Delete original node and mark new node as rewritten. - (*g)->RemoveNode(orig_node); - - VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); - return Status::OK(); -} - -const MklLayoutRewritePass::ContextInfo* -MklLayoutRewritePass::SearchMatchingContext(const Node* n, - const Node** fwd_node) { - CHECK_NOTNULL(n); - CHECK_NOTNULL(fwd_node); - *fwd_node = nullptr; - - // Search for matching contextinfo based on node name and call - // callback function using matching contextinfo. - // There could be more than one matching contextinfos but whichever - // matches first is returned. 
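The out-edge rewiring above leans on GetTensorDataIndex (also from mkl_util.h, not shown in this diff) to translate an output slot of the original node into the corresponding slot of the rewritten node. Following the comment above, a minimal sketch of that mapping is:

    // Illustrative sketch, not the mkl_util.h implementation: the nth
    // Tensorflow output becomes slot 2*n under interleaved ordering and
    // stays at slot n under contiguous ordering (the default).
    int GetTensorDataIndexSketch(int n, bool interleaved) {
      return interleaved ? 2 * n : n;
    }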
- for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { - if (n->type_string() == (*ci)->node && - (*ci)->context_match_fn(n, fwd_node, *ci)) { - VLOG(1) << "Found context as matching: " << (*ci)->fwd; - return *ci; - } - } - return nullptr; -} - -bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n, - const ContextInfo* c) { - const Node* fwd_node = nullptr; - return SearchMatchingContext(n, &fwd_node) == c; -} - -const MklLayoutRewritePass::RewriteInfo* -MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { - CHECK_NOTNULL(n); - - // First check if node along with its type is supported by MKL layer. - // We do not want to rewrite an op into Mkl op if types are not supported. - // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to - // MklRelu if type is INT32. - DataType T; - if (!GetNodeAttr(n->def(), "T", &T).ok()) { - return nullptr; - } - - // BiasAddGrad is not an Mkl layer, so we make an exception for it. - if (n->type_string() != csinfo_.bias_add_grad) { - if (!mkl_op_registry::IsMklOp( - mkl_op_registry::GetMklOpName(n->type_string()), T)) { - return nullptr; - } - } - - // For elementwise node, we reuse the Eigen implementation and pass the MKL - // metadata tensor through so we can avoid conversions. However, if all - // incoming edges are in TF format, we don't need all this overhead, so - // replace the elementwise node only if at least one of its parents is a MKL - // node. - // - // TODO(vrane): Add implementation for element-wise ops that doesn't reuse - // eigen code to reduce cross-library dependency. - if (mkl_op_registry::IsMklElementWiseOp( - mkl_op_registry::GetMklOpName(n->type_string()), T)) { - bool incoming_mkl_edge = false; - for (auto parent : n->in_edges()) { - if (mkl_op_registry::IsMklOp( - mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) { - incoming_mkl_edge = true; - break; - } else { - VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string(); - } - } - if (incoming_mkl_edge == false) { - VLOG(1) << "Skipping replacement of elementwise node which has no MKL " - "parents."; - return nullptr; - } - } - - // We support 2 types of node rewrites: - // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context. - // 2. Rewriting an op to Mkl op always - // We return true if any of these 2 conditions is met. - - // Find matching RewriteInfo and then check that rewrite rule applies. - for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { - if (n->type_string().compare(ri->name) == 0 && - ri->rewrite_rule(n, ri->context)) { - // If we are rewriting BiasAddGrad into BiasAddGrad for MatMul context, - // then we just return directly. - if (n->type_string() == csinfo_.bias_add_grad && - ri->context->fwd == csinfo_.matmul && - ri->new_name == csinfo_.bias_add_grad) { - return nullptr; - } - return &*ri; - } - } - - // Else return not found. - return nullptr; -} - -/////////////////////////////////////////////////////////////////////////////// -// Run function for the pass -/////////////////////////////////////////////////////////////////////////////// - -bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { - bool result = false; - CHECK_NOTNULL(g); - - DumpGraph("Before running MklLayoutRewritePass", &**g); - - std::vector<Node*> order; - GetReversePostOrder(**g, &order); // This will give us topological sort. - - for (Node* n : order) { - // If node is not an op or it cannot run on CPU device, then skip. 
- if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { - continue; - } - - const RewriteInfo* ri = nullptr; - Node* predn = nullptr; - // We will first search if node is to be rewritten - if ((ri = CheckForNodeRewrite(n)) != nullptr) { - string node_name = n->name(); - string op_name = n->type_string(); - - VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name - << " with op " << op_name << " for rewrite using" - << " layout optimization."; - - if (RewriteNode(g, n, ri) == Status::OK()) { - VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name - << " with op " << op_name << " for Mkl layout optimization."; - result = true; - } - } else if ((predn = CheckForNodeMerge(n)) != nullptr) { - // Otherwise, we will check if the node is to be merged. - string n1_name = n->name(); - string n2_name = predn->name(); - - VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " - << n2_name << " for merging"; - - if (MergeNode(g, n, predn) == Status::OK()) { - VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " - << n2_name; - result = true; - } - } - } - - DumpGraph("After running MklLayoutRewritePass", &**g); - - return result; -} - -bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { - return MklLayoutRewritePass().RunPass(g); -} - -Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { - if (options.graph == nullptr && options.partition_graphs == nullptr) { - return Status::OK(); - } - - auto process_graph = [&](std::unique_ptr<Graph>* g) { - // Get the ownership of a graph - std::unique_ptr<Graph>* ng = std::move(g); - RunPass(ng); - // Return the ownership of a graph back - g->reset(ng->release()); - }; - - if (kMklLayoutRewritePassGroup != - OptimizationPassRegistry::POST_PARTITIONING) { - // For any pre-partitioning phase, a graph is stored in options.graph. - process_graph(options.graph); - } else { - // For post partitioning phase, graphs are stored in - // options.partition_graphs. - for (auto& pg : *options.partition_graphs) { - process_graph(&pg.second); - } - } - - return Status::OK(); -} - -#else // INTEL_MKL_ML_ONLY - // This pass implements rewriting of graph to support following scenarios: // (A) Merging nodes in the graph // (B) Rewriting a node in the graph to a new node @@ -4539,7 +2364,7 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { return Status::OK(); } -#endif // INTEL_MKL_ML_ONLY + } // namespace tensorflow #endif diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 77640e287c..0eda8170f8 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -37,1869 +37,6 @@ limitations under the License. 
namespace tensorflow { -#ifdef INTEL_MKL_ML_ONLY - -namespace { - -const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0"; -const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0"; - -static void InitGraph(const string& s, Graph* graph, - const string& device = kCPUDevice) { - GraphDef graph_def; - - auto parser = protobuf::TextFormat::Parser(); - // parser.AllowRelaxedWhitespace(true); - CHECK(parser.MergeFromString(s, &graph_def)) << s; - GraphConstructorOptions opts; - TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph)); - - for (Node* node : graph->nodes()) { - node->set_assigned_device_name(device); - } -} - -class MklLayoutPassTest : public ::testing::Test { - public: - MklLayoutPassTest() : graph_(OpRegistry::Global()) {} - - void InitGraph(const string& s, const string& device = kCPUDevice) { - ::tensorflow::InitGraph(s, &graph_, device); - original_ = CanonicalGraphString(&graph_); - } - - static bool IncludeNode(const Node* n) { return n->IsOp(); } - - static string EdgeId(const Node* n, int index) { - if (index == 0) { - return n->name(); - } else if (index == Graph::kControlSlot) { - return strings::StrCat(n->name(), ":control"); - } else { - return strings::StrCat(n->name(), ":", index); - } - } - - string CanonicalGraphString(Graph* g) { - std::vector<string> nodes; - std::vector<string> edges; - for (const Node* n : g->nodes()) { - if (IncludeNode(n)) { - nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")")); - } - } - for (const Edge* e : g->edges()) { - if (IncludeNode(e->src()) && IncludeNode(e->dst())) { - edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->", - EdgeId(e->dst(), e->dst_input()))); - } - } - // Canonicalize - std::sort(nodes.begin(), nodes.end()); - std::sort(edges.begin(), edges.end()); - return strings::StrCat(str_util::Join(nodes, ";"), "|", - str_util::Join(edges, ";")); - } - - string DoMklLayoutOptimizationPass() { - string before = CanonicalGraphString(&graph_); - LOG(ERROR) << "Before MKL layout rewrite pass: " << before; - - std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_); - RunMklLayoutRewritePass(ug); - - string result = CanonicalGraphString(&graph_); - LOG(ERROR) << "After MKL layout rewrite pass: " << result; - return result; - } - - const string& OriginalGraph() const { return original_; } - - Graph graph_; - string original_; -}; - -REGISTER_OP("Input").Output("o: float").SetIsStateful(); -REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful(); -REGISTER_OP("HalfInput").Output("o: half").SetIsStateful(); -REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful(); -REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful(); -REGISTER_OP("_MklInput2") - .Output("o: uint8") - .Output("o1: uint8") - .SetIsStateful(); - -///////////////////////////////////////////////////////////////////// -// Unit tests related to node merge optiimization -///////////////////////////////////////////////////////////////////// - -TEST_F(MklLayoutPassTest, Basic) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Zeta);D(Zeta)|" - "A->C;A->D;B->C:1;B->D:1"); -} - -// Test set 1: Conv2D + AddBias - -// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for 
interleaved ordering) -// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C', 'D'] }" - "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;" - "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;" - "N->E:4;Y->Z:1"); -} - -// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved) -// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous) -// Test for correct output slots selected -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput2'}" - "node { name: 'N' op: '_MklInput2'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M:1', 'N:1']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C', 'D'] }" - "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);" - "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;" - "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;" - "M:1->E:3;N:1->E:4;Y->Z:1"); -} - -// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y); -// This is a case of node rewrite followed by node merge. -// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D -// with BiasAdd to produce _MklConv2DWithBias. 
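A note on reading the EXPECT_EQ strings in these tests: CanonicalGraphString in the fixture above emits the sorted node list as name(op) entries joined by ';', then a '|', then the sorted edge list, where a ':k' suffix names a non-zero output or input slot and ':control' marks a control edge. For example, the Basic test reduces its four-node graph to

    A(Input);B(Input);C(Zeta);D(Zeta)|A->C;A->D;B->C:1;B->D:1

meaning A feeds input 0 and B feeds input 1 of both Zeta nodes.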
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C', 'D'] }" - "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|" - "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;" - "DMT/_2->E:5;E->Z;Y->Z:1"); -} - -// Graph contains only _MklConv2D, no AddBias. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|" - "A->C;B->C:1;M->C:2;N->C:3"); -} - -// _MklConv2D output does not go to BiasAdd. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd. - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" - "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3"); -} - -// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta). -// Merge should not be done in such case. 
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D', 'E'] }" // Conv2D has two outputs. - // No merge should happen. - "node { name: 'G' op: 'Zeta'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);" - "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;" - "E->F:1;E->G:1;M->C:2;N->C:3"); -} - -// data_format attribute value mismatch. Merge should not be done -// in such case. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHCW' } }" - " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);" - "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3"); -} - -// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias -// rewrite tests - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter -// and BackpropInput -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'F', 
'E', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'Int32Input'}" - "node { name: 'I' op: '_MklConv2DBackpropInput'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['H', 'B', 'E', 'M', 'N', 'O']}" - "node { name: 'J' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" - "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);" - "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;" - "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;" - "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;" - "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); -} - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter -// and BackpropInput. But nodes do not match criteria for rewrite. So -// rewrite should not happen. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'Int32Input'}" - "node { name: 'I' op: '_MklConv2DBackpropInput'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['H', 'B', 'E', 'M', 'N', 'O']}" - "node { name: 'J' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" - "I(_MklConv2DBackpropInput);J(BiasAddGrad);" - "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" - "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" - "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); -} - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter -// and BackpropInput. But nodes do not match criteria for rewrite. 
So -// rewrite should not happen. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['B', 'A', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'Int32Input'}" - "node { name: 'I' op: '_MklConv2DBackpropInput'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['H', 'B', 'E', 'M', 'N', 'O']}" - "node { name: 'J' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);" - "I(_MklConv2DBackpropInput);J(BiasAddGrad);" - "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" - "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;" - "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5"); -} - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 
'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);" - "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" - "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;" - "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;" - "N->D:4;N->G:4;O->D:5;O->G:5"); -} - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only -// But BackpropFilter node inputs do not satisfy criteria for rewrite. -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" - "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" - "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" - "O->G:5"); -} - -// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only -// But BackpropFilter node inputs do not satisfy criteria for rewrite. 
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['B', 'A', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'Int32Input'}" - "node { name: 'G' op: '_MklConv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'F', 'E', 'M', 'N', 'O'] }" - "node { name: 'H' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" - "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;" - "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" - "O->G:5"); -} - -// No _MklConv2DWithBias in context, but _MklConv2D in context. -// No rewrite for BiasAddGrad should happen. -// C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved) -// C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'A']}" - "node { name: 'E' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);" - "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;" - "M->C:2;N->C:3"); -} - -// No Conv2D in the context for BiasAddGrad. No rewrite should happen. 
-// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Polygamma'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'A']}" - "node { name: 'E' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" - "A->C;A->D:1;B->C:1;C->D;D->E"); -} - -// No Conv2D in the context for BiasAddGrad, but MatMul in context. -// Rewrite should happen, but name of BiasAddGrad does not change. -// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'MatMul'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'transpose_a' value { b: false } }" - " attr { key: 'transpose_b' value { b: false } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'A']}" - "node { name: 'E' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" - "A->C;A->D:1;B->C:1;C->D;D->E"); -} - -// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests -// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D) -TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'MatMul'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'transpose_a' value { b: false } }" - " attr { key: 'transpose_b' value { b: false } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'A']}" - "node { name: 'E' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|" - "A->C;A->D:1;B->C:1;C->D;D->E"); -} - -// No MatMul in the context for BiasAddGrad. No rewrite should happen. 
-// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D) -TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Polygamma'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'A']}" - "node { name: 'E' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|" - "A->C;A->D:1;B->C:1;C->D;D->E"); -} - -///////////////////////////////////////////////////////////////////// -// Unit tests related to rewriting node to Mkl node -///////////////////////////////////////////////////////////////////// - -// Single Conv2D Op; No Mkl layer on the input and on the output. -// We will generate dummy Mkl tensor as 2nd input of Conv2D. -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);" - "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" - "DMT/_1->C:3"); -} - -// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will -// have 2 outputs, both of which will be inputs to next Conv2D. 
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'C']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;" - "A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;" - "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2"); -} - -// Conv2D with DT_HALF, a type not supported by Mkl -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) { - InitGraph( - "node { name: 'A' op: 'HalfInput'}" - "node { name: 'B' op: 'HalfInput'}" - "node { name: 'C' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_HALF } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }" - " input: ['B', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|" - "A->C;B->C:1;B->D;C->D:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Int32Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Conv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);" - "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" - "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" - "DMT/_1->D:4;DMT/_2->D:5"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Int32Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Conv2DBackpropInput'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['B', 'A', 'C']}" - "node { name: 'E' op: 'Zeta' attr { 
key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);" - "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|" - "A->D:1;A->E;B->D;B:control->DMT/_0:control;" - "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;" - "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); -} - -// Concat Op test: Concat with no Mkl layer feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) { - InitGraph( - "node { name: 'A' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'B' op: 'InputList'" - " attr { key: 'N' value { i: 2 } }}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Concat'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['A', 'B:0', 'B:1']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }"); - EXPECT_EQ( - DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;" - "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); -} - -// Concat with 2 Mkl layers feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'F' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['C', 'D']}" - "node { name: 'G' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'H' op: 'Concat'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['G', 'E', 'F']}" - "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'H'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;" - "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;" - "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;" - "G:control->DMT/_4:control;H->I:1"); -} - -// Concat with 1 Mkl and 1 non-Mkl layer feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 
'Input'}" - "node { name: 'E' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D']}" - "node { name: 'G' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'H' op: 'Concat'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['G', 'E', 'F']}" - "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'H'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" - "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" - "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;" - "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1"); -} - -// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) { - InitGraph( - "node { name: 'A' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'B' op: 'InputList'" - " attr { key: 'N' value { i: 2 } }}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'ConcatV2'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'Tidx' value { type: DT_INT32 } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['B:0', 'B:1', 'A']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;" - "B:control->DMT/_0:control;B:control->DMT/_1:control;" - "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;" - "DMT/_1->D:4;DMT/_2->D:5"); -} - -// ConcatV2 with 2 Mkl layers feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'F' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['C', 'D']}" - "node { name: 'G' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 
'H' op: 'ConcatV2'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'Tidx' value { type: DT_INT32 } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['E', 'F', 'G']}" - "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'H'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);" - "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;" - "C:control->DMT/_0:control;C:control->DMT/_1:control;" - "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;" - "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;" - "F:2->H:4;G->H:2;H->I:1"); -} - -// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D']}" - "node { name: 'G' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'H' op: 'ConcatV2'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'Tidx' value { type: DT_INT32 } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['E', 'F', 'G']}" - "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'H'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);" - "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;" - "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;" - "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;" - "G->H:2;H->I:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Relu'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" - "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'ReluGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" - "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); -} - -TEST_F(MklLayoutPassTest, 
NodeRewrite_ReluReluGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Relu'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A'] }" - "node { name: 'C' op: 'ReluGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);" - "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;" - "DMT/_1->C:2"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'AvgPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;" - "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Int32Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'AvgPoolGrad' " - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" - "DMT/_1(Const)|A->C;A:control->DMT/_0:control;" - "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;" - "DMT/_1->C:3"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'I' op: 'Int32Input'}" - "node { name: 'B' op: 'AvgPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A'] }" - "node { name: 'C' op: 'AvgPoolGrad' " - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['I', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);" - "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;" - "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;" - "I:control->DMT/_1:control"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { 
name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'FusedBatchNormGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'epsilon' value { f: 0.0001 } }" - " attr { key: 'is_training' value { b: true } }" - " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" - "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;" - "A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;" - "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" - "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" - "E->F:4;F->G:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'FusedBatchNorm'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'epsilon' value { f: 0.0001 } }" - " attr { key: 'is_training' value { b: true } }" - " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);" - "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;" - "A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;" - "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;" - "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;" - "E->F:4;F->G:1"); -} - -///////////////////////////////////////////////////////////////////// -// Unit tests related to rewriting node for workspace edges -///////////////////////////////////////////////////////////////////// - -/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. 
*/ -TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'LRN'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['A'] }" - "node { name: 'C' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['B'] }" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'MaxPoolGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['B', 'C', 'D'] }" - "node { name: 'F' op: 'Input'}" - "node { name: 'G' op: 'LRNGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['E', 'F', 'B'] }" - "node { name: 'H' op: 'Input'}" - "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['H', 'G'] }"); - EXPECT_EQ( - DoMklLayoutOptimizationPass(), - "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);" - "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;" - "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;" - "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;" - "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I"); -} - -/* Test LRN->LRNGrad replacement by workspace nodes. 
*/ -TEST_F(MklLayoutPassTest, LRN_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'LRN'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['A'] }" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'LRNGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['C', 'D', 'B'] }" - "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|" - "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;" - "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" - "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1"); -} - -/* Test LRN->LRNGrad replacement when only one of them is present. */ -TEST_F(MklLayoutPassTest, LRN_Negative1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'LRN'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|" - "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); -} - -/* Test LRN->LRNGrad replacement when only one of them is present. */ -TEST_F(MklLayoutPassTest, LRN_Negative2) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'LRNGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['A', 'B', 'C'] }" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" - "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;" - "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" - "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); -} - -/* Test LRN->LRNGrad negative case, where single LRN feeds - 2 LRNGrad nodes at different slots. 
*/ -TEST_F(MklLayoutPassTest, LRN_Negative3) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'LRN'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['A'] }" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'LRNGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['C', 'D', 'B'] }" - "node { name: 'F' op: 'LRNGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'alpha' value { f: 0.001 } }" - " attr { key: 'beta' value { f: 0.75 } }" - " attr { key: 'bias' value { f: 1.0 } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'depth_radius' value { i: 2 } }" - " input: ['C', 'B', 'D'] }" - "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'F'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);" - "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);" - "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;" - "A:control->DMT/_0:control;B->E:2;" - "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;" - "C:control->DMT/_1:control;C:control->DMT/_2:control;" - "C:control->DMT/_3:control;C:control->DMT/_4:control;" - "C:control->DMT/_5:control;C:control->DMT/_6:control;" - "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;" - "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1"); -} - -/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */ -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'MaxPoolGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['C', 'B', 'D'] }" - "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'E'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|" - "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;" - "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;" - "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1"); -} - -// Test MaxPool>MaxPoolGrad replacement when only one of them is present. -// In this case, we will rewrite MaxPool node but workspace edges will not -// be present. 
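Most of the MaxPool cases further down (Negative3 through Negative10) differ only in which slot of the 4-element ksize/strides attribute is set to 2: the slots run N, C, H, W for NCHW and N, H, W, C for NHWC, and a window or stride over the batch or depth slot is exactly what those tests expect to leave MaxPool unrewritten. The standalone sketch below is my own illustration of that condition, not code from the layout pass or the test file.

    #include <array>
    #include <cassert>
    #include <string>

    // True when a 4-element ksize or strides attribute touches the batch or
    // depth (channel) dimension for the given data_format; the tests expect
    // MaxPool to stay unrewritten in that case (Negative3-Negative10).
    bool PoolsOverBatchOrDepth(const std::string& data_format,
                               const std::array<int, 4>& v) {
      const int batch = 0;                                // N is slot 0 in both formats
      const int depth = (data_format == "NCHW") ? 1 : 3;  // C slot differs by format
      return v[batch] != 1 || v[depth] != 1;
    }

    int main() {
      assert(PoolsOverBatchOrDepth("NCHW", {2, 1, 1, 1}));   // batch-wise, as in Negative3
      assert(PoolsOverBatchOrDepth("NCHW", {1, 2, 1, 1}));   // depth-wise, as in Negative5
      assert(!PoolsOverBatchOrDepth("NCHW", {1, 1, 3, 3}));  // spatial pooling, as in the positive tests
      return 0;
    }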
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|" - "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); -} - -// Test MaxPoolGrad replacement when only one of them is present. -// In this case, we will rewrite MaxPoolGrad and for workspace tensor and -// its Mkl part, we will generate dummy tensor. -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'MaxPoolGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }" - " input: ['A', 'B', 'C'] }" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);" - "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|" - "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;" - "A:control->DMT/_2:control;A:control->DMT/_3:control;" - "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;" - "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6"); -} - -// Test MaxPool handling for batch-wise pooling (NCHW) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for batch-wise pooling (NCHW) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for depth-wise pooling (NHWC) -// No rewrite should take place in such 
case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for depth-wise pooling (NCHW) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for batch-wise pooling (NHWC) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for batch-wise pooling (NHWC) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for depth-wise pooling (NHWC) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); 
- EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Test MaxPool handling for depth-wise pooling (NHWC) -// No rewrite should take place in such case -TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }"); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -///////////////////////////////////////////////////////////////////// - -// Single Conv2D Op on GPU device -// No rewrite should happen -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Conv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B']}" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['B', 'C'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1"); -} - -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'O' op: '_MklInput'}" - "node { name: 'D' op: '_MklConv2DWithBias'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C', 'M', 'N', 'O']}" - "node { name: 'E' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['D', 'A']}" - "node { name: 'F' op: 'BiasAddGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['E'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" - "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);" - "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;" - "M->D:3;N->D:4;O->D:5"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Int32Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Conv2DBackpropFilter'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'C']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'D'] }", - kGPUDevice); - 
EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|" - "A->D;A->E;B->D:1;C->D:2;D->E:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Relu'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'ReluGrad'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }" - "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'C'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'MaxPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'AvgPool'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NHWC' } }" - " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'VALID' } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " input: ['A'] }" - "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'B'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1"); -} - -// Concat Op test: Concat with no Mkl layer feeding it -TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'B' op: 'InputList'" - " attr { key: 'N' value { i: 2 } }}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Concat'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['A', 'B:0', 'B:1']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;" - "B->D:1;B:1->D:2;C->E;D->E:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Const' " - " attr { key: 'dtype' value { type: DT_INT32 } }" - " attr { key: 'value' value { " - " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " - " int_val: 0 } } } }" - "node { name: 'B' op: 'InputList'" - 
" attr { key: 'N' value { i: 2 } }}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'ConcatV2'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'Tidx' value { type: DT_INT32 } }" - " attr { key: 'N' value { i: 2 } }" - " input: ['B:0', 'B:1', 'A']}" - "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['C', 'D'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" - "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); -} - -TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'C' op: 'Input'}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'Input'}" - "node { name: 'F' op: 'FusedBatchNorm'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'epsilon' value { f: 0.0001 } }" - " attr { key: 'is_training' value { b: true } }" - " input: ['A', 'B', 'C', 'D', 'E'] }" - "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" - " input: ['A', 'F'] }", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(Input);D(Input);E(Input);" - "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;" - "E->F:4;F->G:1"); -} - -TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { - CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); - InitGraph( - "node { name: 'A' op: 'Input'}" - "node { name: 'B' op: 'Input'}" - "node { name: 'M' op: '_MklInput'}" - "node { name: 'N' op: '_MklInput'}" - "node { name: 'C' op: '_MklConv2D'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " attr { key: 'use_cudnn_on_gpu' value { b: false } }" - " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" - " attr { key: 'padding' value { s: 'SAME' } }" - " input: ['A', 'B', 'M', 'N']}" - "node { name: 'D' op: 'Input'}" - "node { name: 'E' op: 'BiasAdd'" - " attr { key: 'T' value { type: DT_FLOAT } }" - " attr { key: 'data_format' value { s: 'NCHW' } }" - " input: ['C', 'D'] }" - "node { name: 'Y' op: 'Input'}" - "node { name: 'Z' op: 'Zeta'" - " attr {key: 'T' value { type: DT_FLOAT } }" - " input: ['E', 'Y']}", - kGPUDevice); - EXPECT_EQ(DoMklLayoutOptimizationPass(), - "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);" - "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;" - "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); -} - -///////////////////////////////////////////////////////////////////// - -static void BM_MklLayoutRewritePass(int iters, int op_nodes) { - testing::StopTiming(); - string s; - for (int in = 0; in < 10; in++) { - s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); - } - random::PhiloxRandom philox(301, 17); - random::SimplePhilox rnd(&philox); - for (int op = 0; op < op_nodes; op++) { - s += strings::Printf( - "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { " - "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }", - op, rnd.Uniform(10), rnd.Uniform(10)); - } - - bool first = true; - while (iters > 0) { - Graph* graph = new Graph(OpRegistry::Global()); - InitGraph(s, graph); - int N = graph->num_node_ids(); - if (first) { - testing::SetLabel(strings::StrCat("Per graph node. 
Nodes: ", N)); - first = false; - } - { - testing::StartTiming(); - std::unique_ptr<Graph> ug(graph); - RunMklLayoutRewritePass(&ug); - testing::StopTiming(); - } - iters -= N; // Our benchmark units are individual graph nodes, - // not whole graphs - // delete graph; - } -} -BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); - -} // namespace - -#else // INTEL_MKL_ML_ONLY - // NOTE: Unit tests in this file rely on a topological sorted graph for // printing. But since sibling nodes of a node in the topologically sorted graph // can be printed in different orders, tests may fail if the order in which @@ -3602,8 +1739,6 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace -#endif // INTEL_MKL_ML_ONLY - } // namespace tensorflow #endif // INTEL_MKL && ENABLE_MKL |