 tensorflow/core/graph/mkl_layout_pass.cc      | 2177
 tensorflow/core/graph/mkl_layout_pass_test.cc | 1865
 2 files changed, 1 insertion, 4041 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 7394b1cddf..42a35727db 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -45,2181 +45,6 @@ limitations under the License.
namespace tensorflow {
-#ifdef INTEL_MKL_ML_ONLY
-
-// This pass implements rewriting of the graph to support the following
-// scenarios:
-// (A) Merging nodes in the graph
-// (B) Rewriting a node in the graph to a new node
-// Rewrite happens under the following 2 scenarios:
-// 1) Propagating Mkl layout as an additional output tensor
-// (we will loosely refer to a tensor that carries the Mkl layout as an Mkl
-// tensor henceforth) from every Mkl-supported NN layer.
-// 2) Context-based rewrite: This is needed in order to optimize
-// gradient ops of Conv2D+BiasAdd. Gradient op of both the Conv2D and
-// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
-// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
-// This is context-specific optimization, where the context is the
-// forward operator that the BiasAddGrad corresponds to.
-//
-// Example of A : Merging nodes in the graph
-// -----------------------------------------
-// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
-//
-// O = Conv2D(A, B)
-// P = BiasAdd(O, C)
-//
-// We merge them into Conv2DWithBias as:
-// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
-//
-// The meaning of A_m, B_m and C_m is explained in B.1.
-//
-// Merge rules:
-// - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
-// goes to BiasAdd.
-// - Also, the intersection of the attributes of both nodes must have the
-// same values.
-// - Both nodes must have been assigned to the same device (if any).
-//
-// Example of B.1 : Rewriting nodes to Mkl nodes
-// ---------------------------------------------
-// Consider a Relu node. Current definition of Relu node looks like:
-//
-// O = Relu(A)
-//
-// Relu has 1 input (A), and 1 output (O).
-//
-// This rewrite pass will generate a new graph node for Relu (new node is
-// called MklRelu) as:
-//
-// O, O_m = MklRelu(A, A_m)
-//
-// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
-// the same as input A of Relu; output O is the same as output O of Relu. O_m
-// is the additional output tensor that will be set by MklRelu, and it
-// represents the Mkl tensor corresponding to O -- in other words, O_m is some
-// kind of metadata for O. A_m is an additional input of MklRelu, and it
-// represents metadata for A -- as O_m is metadata for O, A_m is metadata for
-// A. MklRelu receives this metadata from the previous node in the graph.
-//
-// When a previous node in the graph is an Mkl node, A_m will represent a valid
-// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
-// a dummy Mkl tensor.
-//
-// Rewriting rules:
-// - Selection of a node for rewriting happens by registering the op type of
-// the node with the rewriting pass. If the op type is not registered, then
-// no node of this op type will be rewritten.
-// - Number of inputs after rewriting:
-// Since for every input Tensorflow tensor, the rewritten node gets Mkl
-// tensor(s), rewritten node gets 2*N inputs, where N is the number of
-// inputs for the original node.
-// - Number of outputs after rewriting:
-// Since for every output Tensorflow tensor, the rewritten node generates
-// Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
-// number of outputs of the original node.
-// - Ordering of Tensorflow tensors and Mkl tensors:
-// Since every rewritten node generates twice the number of inputs and
-// outputs, one could imagine various orderings among Tensorflow tensors
-// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
-// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
-// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
-// order. Among N inputs one can get N! permutations.
-//
-// So the question is: which order do we follow? We support 2 types of
-// orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
-// follows an intuitive order where an Mkl tensor follows the
-// corresponding Tensorflow tensor immediately. In the context of the
-// above example, it will be: A, A_m, B, B_m. Note that the ordering rule
-// applies to both the inputs and outputs. Contiguous ordering means
-// all the Tensorflow tensors are contiguous followed by all the Mkl
-// tensors. We use contiguous ordering as default.
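-//
-// For example, with the default contiguous ordering, the Conv2D rewrite
-// sketched above takes the form:
-//
-// O, O_m = _MklConv2D(A, B, A_m, B_m)
-//
-// so the Mkl tensor for the i-th Tensorflow tensor appears after all the
-// Tensorflow tensors, at slot N+i (N being the number of Tensorflow tensors
-// on that side of the node).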
-//
-// Graph rewrite algorithm:
-// Algorithm: Graph Rewrite
-// Input: Graph G, Names of the nodes to rewrite and their new names
-// Output: Modified Graph G' if the nodes are modified, G otherwise.
-// Start:
-// N = Topological_Sort(G) // N is a set of nodes in toposort order.
-// foreach node n in N
-// do
-// if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input.
-// then
-// E = set of <incoming edge and its src_output slot> of n
-// E' = {} // a new set of edges for rewritten node
-// foreach <e,s> in E
-// do
-// E' U {<e,s>} // First copy edge which generates Tensorflow
-// // tensor as it is
-// m = Source node of edge e
-// if Is_Rewritten(m) // Did we rewrite this node in this pass?
-// then
-// E' U {<m,s+1>} // If yes, then m will generate an Mkl
-// // tensor as an additional output.
-// else
-// d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy
-// // Mkl tensor.
-// E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot.
-// fi
-// done
-// n' = Build_New_Node(G,new_name,E')
-// Mark_Rewritten(n') // Mark the new node as being rewritten.
-// fi
-// done
-//
-// Explanation:
-// For graph rewrite, we visit nodes of the input graph in
-// topological sort order. With this ordering, we visit nodes in a
-// top-to-bottom fashion. We need this order because, while visiting a
-// node, we want all of its input nodes to have been visited (and
-// rewritten if applicable). This is because if we need to rewrite a
-// given node, then all of its input nodes need to be fixed (in other
-// words, they cannot be deleted later).
-//
-// While visiting a node, we first check if the op type of the node is
-// an Mkl op. If it is, then we rewrite that node after constructing
-// new inputs to the node. If the op type of the node is not Mkl op,
-// then we do not rewrite that node.
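-//
-// As a concrete illustration of the dummy-tensor step above: if Relu's
-// input A is produced by a node that was not rewritten, the pass creates
-// a dummy Mkl tensor node D and builds
-//
-// O, O_m = MklRelu(A, D)
-//
-// where D feeds its only output (slot 0) as the Mkl metadata for A.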
-//
-// Handling workspace propagation for certain ops:
-//
-// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
-// passing of a workspace from their respective forward ops. Workspace
-// tensors provide memory for storing results of intermediate operations
-// which are helpful in backward propagation. TensorFlow does not have
-// a notion of a workspace and as a result does not allow producing
-// additional outputs from these forward ops. For these ops, we need
-// to add 2 extra edges between forward ops and their corresponding
-// backward ops - the first extra edge carries a workspace tensor and
-// the second one carries an Mkl tensor for the workspace tensor.
-//
-// Example:
-//
-// Typical graph for MaxPool and its gradient looks like:
-//
-// A = MaxPool(T)
-// B = MaxPoolGrad(X, A, Y)
-//
-// We will transform this graph to propagate the workspace as:
-// (with the contiguous ordering)
-//
-// A, W, A_m, W_m = MklMaxPool(T, T_m)
-// B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
-//
-// Here W is the workspace tensor. Transformed tensor names with the
-// suffix _m are Mkl tensors, and this transformation has been done
-// using the algorithm discussed earlier. The transformation for
-// workspace propagation only adds extra outputs (W, W_m) for a forward
-// op and connects them to the corresponding backward ops.
-//
-// Terms:
-//
-// Forward op name = name of the op in the forward pass
-// where a workspace tensor originates (MaxPool in this example)
-// Backward op name = name of the op in the backward pass that receives
-// a workspace tensor from the forward op (MaxPoolGrad in the example)
-// Slot = Position of the output or input slot that will be
-// used by the workspace tensor (1 for MklMaxPool as W is the 2nd
-// output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
-//
-// Question:
-//
-// How do we associate a backward op to a forward op? There can be more
-// than one op with the exact same name.
-//
-// In this example, we associate MaxPoolGrad with MaxPool. But there
-// could be more than one MaxPool op. To solve this problem, we look
-// for a _direct_ edge between a forward op and a backward op (tensor A is
-// flowing along this edge in the example).
-//
-// How do we transform forward and backward ops when there is no direct
-// edge between them? In such a case, we generate dummy tensors for
-// workspace tensors. For the example, transformation of MaxPool will
-// be exactly the same as it would be when there is a direct edge between
-// the forward and the backward op --- it is just that MaxPool won't
-// generate any workspace tensor. For MaxPoolGrad, the transformation
-// will also be same, but instead of connecting W and W_m with the
-// outputs of MaxPool, we will produce dummy tensors for them, and we
-// will set workspace_enabled attribute to false.
-//
-// Example of B.2 : Context-based node rewrite
-// -------------------------------------------
-// Consider BiasAddGrad op as:
-//
-// O = _MklConv2D(A, B, C, A_m, B_m, C_m)
-// P = BiasAddGrad(O)
-//
-// Then we rewrite it as:
-//
-// P = Conv2DWithBiasBackpropBias(O, O_m)
-//
-// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
-// on the matching 'context'. The term context is loosely related to which
-// forward op is _associated_ with BiasAddGrad. If it is _MklConv2DWithBias
-// then we consider it a Conv2D context; if it is MatMul, then it is a MatMul
-// context.
-
-class MklLayoutRewritePass : public GraphOptimizationPass {
- public:
- MklLayoutRewritePass() {
- // NOTE: names are alphabetically sorted.
- csinfo_.addn = "AddN";
- csinfo_.avg_pool = "AvgPool";
- csinfo_.avg_pool_grad = "AvgPoolGrad";
- csinfo_.bias_add = "BiasAdd";
- csinfo_.bias_add_grad = "BiasAddGrad";
- csinfo_.concat = "Concat";
- csinfo_.concatv2 = "ConcatV2";
- csinfo_.conv2d = "Conv2D";
- csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
- csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
- csinfo_.fused_batch_norm = "FusedBatchNorm";
- csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
- csinfo_.identity = "Identity";
- csinfo_.lrn = "LRN";
- csinfo_.lrn_grad = "LRNGrad";
- csinfo_.matmul = "MatMul";
- csinfo_.max_pool = "MaxPool";
- csinfo_.max_pool_grad = "MaxPoolGrad";
- csinfo_.mkl_conv2d = "_MklConv2D";
- csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
- csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
- csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
- csinfo_.mkl_conv2d_with_bias_backprop_bias =
- "_MklConv2DWithBiasBackpropBias";
- csinfo_.relu = "Relu";
- csinfo_.relu_grad = "ReluGrad";
- csinfo_.reshape = "Reshape";
- csinfo_.split = "Split";
- // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
- // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
- // MklInputConversion op is added before it.
- csinfo_.add = "Add";
- csinfo_.maximum = "Maximum";
- csinfo_.mul = "Mul";
- csinfo_.squared_difference = "SquaredDifference";
- csinfo_.sub = "Sub";
- // End - element-wise ops. See note above.
-
- // NOTE: names are alphabetically sorted.
- rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
- CopyAttrsAddN, AddNRewrite, nullptr});
- rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.avg_pool,
- mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
- CopyAttrsPooling, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.avg_pool_grad,
- mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
- CopyAttrsPooling, AlwaysRewrite, nullptr});
- // BiasAddGrad gets rewritten into Conv2DWithBiasBackpropBias depending
- // on whether the context contains Conv2D.
- rinfo_.push_back({csinfo_.bias_add_grad,
- csinfo_.mkl_conv2d_with_bias_backprop_bias,
- CopyAttrsBiasAddGrad, ContextMatchRewrite,
- &biasaddgrad_conv2dwithbias_context_});
- // BiasAddGrad gets rewritten into BiasAddGrad depending on whether the
- // context contains MatMul.
- rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
- CopyAttrsBiasAddGrad, ContextMatchRewrite,
- &biasaddgrad_matmul_context_});
- rinfo_.push_back({csinfo_.concat,
- mkl_op_registry::GetMklOpName(csinfo_.concat),
- CopyAttrsConcat, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.concatv2,
- mkl_op_registry::GetMklOpName(csinfo_.concatv2),
- CopyAttrsConcatV2, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.conv2d,
- mkl_op_registry::GetMklOpName(csinfo_.conv2d),
- CopyAttrsConv2D, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.conv2d_grad_filter,
- mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
- CopyAttrsConv2D, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.conv2d_grad_input,
- mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
- CopyAttrsConv2D, AlwaysRewrite, nullptr});
-
- rinfo_.push_back({csinfo_.fused_batch_norm,
- mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
- CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
- rinfo_.push_back(
- {csinfo_.fused_batch_norm_grad,
- mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
- CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.identity,
- mkl_op_registry::GetMklOpName(csinfo_.identity),
- CopyAttrsIdentity, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
- CopyAttrsLRN, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.lrn_grad,
- mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
- CopyAttrsLRN, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.max_pool,
- mkl_op_registry::GetMklOpName(csinfo_.max_pool),
- CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
- rinfo_.push_back({csinfo_.max_pool_grad,
- mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
- CopyAttrsPooling, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.maximum,
- mkl_op_registry::GetMklOpName(csinfo_.maximum),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.relu_grad,
- mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.reshape,
- mkl_op_registry::GetMklOpName(csinfo_.reshape),
- CopyAttrsReshape, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.squared_difference,
- mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
- rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
- CopyAttrsDataType, AlwaysRewrite, nullptr});
-
- // Add info about which ops to add workspace edge to and the slots.
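- // Each entry below is {fwd_op, bwd_op, fwd_slot, bwd_slot, ws_fwd_slot,
- // ws_bwd_slot}; e.g. for MaxPool, output slot 0 of MaxPool feeds input
- // slot 1 of MaxPoolGrad, the workspace is emitted at output slot 1 of
- // MklMaxPool, and it is consumed at input slot 3 of MklMaxPoolGrad.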
- wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
- wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
-
- // Add a rule for merging nodes
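- // The entry below merges _MklConv2D (pred) feeding input 0 of BiasAdd
- // (succ) into a single _MklConv2DWithBias node.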
- minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
- csinfo_.mkl_conv2d_with_bias});
-
- biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
- IsBiasAddGradInMatMulContext};
-
- biasaddgrad_conv2dwithbias_context_ = {
- csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
- IsBiasAddGradInConv2DWithBiasContext};
-
- cinfo_.push_back(&biasaddgrad_matmul_context_);
- cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
- }
-
- // Standard interface to run pass
- Status Run(const GraphOptimizationPassOptions& options);
-
- // Helper function which does most of the heavy lifting for rewriting
- // Mkl nodes to propagate Mkl tensor as additional output
- //
- // Extracts common functionality between Run public interface and
- // test interface.
- //
- // @return true, if and only if graph is mutated; false otherwise.
- bool RunPass(std::unique_ptr<Graph>* g);
-
- /// Structure to specify the context information used in a node rewrite rule
- typedef struct {
- string node; // Name of the node to be rewritten
- string fwd; // Name of the node in the forward pass that this node
- // corresponds to
- std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
- } ContextInfo;
-
- /// Structure to specify the name of an original node, its new name after
- /// rewrite, the number of inputs to the original node, the function to
- /// be used to copy attributes for the op, and the rule (if any) which
- /// must hold for rewriting the node
- typedef struct {
- string name; // Original name of op of the node in the graph
- string new_name; // New name of the op of the node in the graph
- // A function handler to copy attributes from an old node to a new node.
- std::function<void(const Node*, NodeBuilder*)> copy_attrs;
- // A rule under which to rewrite this node
- std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
- // ContextInfo, if any, to be used for rewrite
- ContextInfo* context;
- } RewriteInfo;
-
- /// Structure to specify a forward op, a backward op, and the slot numbers
- /// in the forward and backward ops where we will add a workspace edge.
- typedef struct {
- string fwd_op; // Name of a forward op in the graph
- string bwd_op; // Name of a backward op in the graph
- int fwd_slot; // Output slot in the forward op node where actual
- // output tensor resides
- int bwd_slot; // Input slot in the backward op node where actual
- // input tensor resides
- int ws_fwd_slot; // Output slot in the forward op node where workspace
- // edge is added
- int ws_bwd_slot; // Input slot in the backward op node where workspace
- // edge is added
- } WorkSpaceInfo;
-
- /// Structure to specify information used in node merge
- typedef struct {
- string pred; // Predecessor node string
- string succ; // Successor node string
- int op; // The operand number of the successor node that the
- // predecessor node's output corresponds to
- string new_node; // Name of the node after merge
- } MergeInfo;
-
- /// Structure to store all constant strings
- /// NOTE: names are alphabetically sorted.
- typedef struct {
- string addn;
- string add;
- string avg_pool;
- string avg_pool_grad;
- string bias_add;
- string bias_add_grad;
- string concat;
- string concatv2;
- string conv2d;
- string conv2d_grad_input;
- string conv2d_grad_filter;
- string fused_batch_norm;
- string fused_batch_norm_grad;
- string identity;
- string lrn;
- string lrn_grad;
- string matmul;
- string max_pool;
- string max_pool_grad;
- string maximum;
- string mkl_conv2d;
- string mkl_conv2d_grad_input;
- string mkl_conv2d_grad_filter;
- string mkl_conv2d_with_bias;
- string mkl_conv2d_with_bias_backprop_bias;
- string mul;
- string relu;
- string relu_grad;
- string reshape;
- string split;
- string squared_difference;
- string sub;
- } ConstStringsInfo;
-
- private:
- /// Maintain info about nodes to rewrite
- std::vector<RewriteInfo> rinfo_;
-
- /// Maintain info about nodes to add workspace edge
- std::vector<WorkSpaceInfo> wsinfo_;
-
- /// Maintain info about nodes to be merged
- std::vector<MergeInfo> minfo_;
-
- /// Maintain info about nodes to rewrite
- static std::vector<ContextInfo*> cinfo_;
-
- /// Maintain structure of constant strings
- static ConstStringsInfo csinfo_;
-
- /// Context variables used in referencing rules
- static ContextInfo biasaddgrad_matmul_context_;
- static ContextInfo biasaddgrad_conv2dwithbias_context_;
-
- private:
- // Is OpDef::ArgDef a list type? It could be N * T or list(type).
- // Refer to opdef.proto for details of list type.
- inline bool ArgIsList(const OpDef::ArgDef& arg) const {
- return !arg.type_list_attr().empty() || !arg.number_attr().empty();
- }
-
- // Get length of a list in 'n' if 'arg' is of list type. Refer to
- // description of ArgIsList for definition of list type.
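- // For example, AddN declares its input as 'inputs: N * T', so this
- // function returns the value of the 'N' attr of node 'n'.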
- inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
- CHECK_EQ(ArgIsList(arg), true);
- int N = 0;
- const string attr_name = !arg.type_list_attr().empty()
- ? arg.type_list_attr()
- : arg.number_attr();
- if (!arg.type_list_attr().empty()) {
- std::vector<DataType> value;
- TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
- N = value.size();
- } else {
- TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
- }
- return N;
- }
-
- // Can op represented by node 'n' run on DEVICE_CPU?
- // Op can run on CPU with MKL if the runtime assigned device or the
- // user requested device contains device CPU, or both are empty.
- bool CanOpRunOnCPUDevice(const Node* n) {
- bool result = true;
- string reason;
-
- // Substring that should be checked for in device name for CPU device.
- const char* const kCPUDeviceSubStr = "CPU";
-
- // If Op has been specifically assigned to a non-CPU device, then No.
- if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
- result = false;
- reason = "Op has been assigned a runtime device that is not CPU.";
- }
-
- // If user has specifically assigned this op to a non-CPU device, then No.
- if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
- result = false;
- reason = "User has assigned a device that is not CPU.";
- }
-
- if (result == false) {
- VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
- << n->type_string() << ", reason: " << reason;
- }
-
- // Otherwise Yes.
- return result;
- }
-
- // Return a node that can be merged with input node 'n'
- //
- // @return pointer to the node if we can find such a
- // node. Otherwise, it returns nullptr.
- Node* CheckForNodeMerge(const Node* n) const;
-
- // Merge predecessor node with its successor.
- // Currently, we merge Conv2D with BiasAdd only.
- //
- // Input nodes succ and pred may be deleted if the call to
- // this function is successful. Attempting to use these pointers
- // after the call may result in undefined behavior.
- //
- // @input g - input graph, succ - successor node, pred - predecessor node
- // @return Status::OK(), if merging is successful and supported.
- // Returns appropriate Status error code otherwise.
- // Graph is updated in case nodes are merged. Otherwise, it is
- // not updated.
- Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
-
- // Check if the node 'n' has any applicable rewrite rule
- // We check for 2 scenarios for rewrite.
- //
- // @return RewriteInfo* for the applicable rewrite rule
- const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
-
- // Default rewrite rule to be used in scenario 1 for rewrite.
- // @return - true (since we want to always rewrite)
- static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
- return true;
- }
-
- // Check if we are performing pooling on depth or batch. If so, then we
- // do not rewrite the MaxPool node to its Mkl version.
- // @return - true (if it is not a depth/batch wise pooling case);
- // false otherwise.
- static bool NonDepthBatchWisePoolRewrite(const Node* n,
- const ContextInfo* c) {
- CHECK_NOTNULL(n);
-
- string data_format_str;
- TensorFormat data_format;
- std::vector<int32> ksize, strides;
- CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
- CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
- CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(), true);
- CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
-
- // Condition that specifies non-batch-wise and non-depth-wise pooling.
- if (GetTensorDim(ksize, data_format, 'N') == 1 &&
- GetTensorDim(strides, data_format, 'N') == 1 &&
- GetTensorDim(ksize, data_format, 'C') == 1 &&
- GetTensorDim(strides, data_format, 'C') == 1) {
- return true;
- }
-
- return false;
- }
-
- static bool AddNRewrite(const Node* n, const ContextInfo* c) {
- CHECK_NOTNULL(n);
-
- int num;
- CHECK_EQ(GetNodeAttr(n->def(), "N", &num).ok(), true);
-
- // Rewrite AddN only when it has exactly 2 inputs.
- if (num == 2) {
- return true;
- }
-
- return false;
- }
- // Is the BiasAddGrad node 'n' associated with the Conv2DWithBias node
- // specified in contextinfo 'ci'? The function updates fwd_node to point
- // to the Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
- //
- // Association checks for one of the following graphs:
- //
- // Graph A:
- //
- // _ = Conv2DWithBias(F, I, _)
- // ..
- // _ = Conv2DBackpropFilter(F, _, G)
- // _ = Conv2DBackpropInput(_, I, G)
- // _ = BiasAddGrad(G)
- //
- // OR
- //
- // Graph B:
- //
- // _ = Conv2DWithBias(F, _, _)
- // ..
- // _ = Conv2DBackpropFilter(F, _, G)
- // _ = BiasAddGrad(G)
- //
- // Here F, G, and I are graph nodes; _ represents graph nodes that we
- // don't care about here.
- //
- // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
- // false otherwise.
- static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
- const Node** fwd_node,
- void* ci) {
- CHECK_NOTNULL(n);
- CHECK_NOTNULL(fwd_node);
- CHECK_NOTNULL(ci);
- *fwd_node = nullptr;
-
- CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);
-
- // Get the only input of BiasAddGrad.
- CHECK_EQ(n->num_inputs(), 1);
- const Node* bias_add_grad_inp = nullptr;
- TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
- CHECK_NOTNULL(bias_add_grad_inp);
-
- // Check if this input also goes to BackpropFilter and BackpropInput
- // as 3rd input.
- bool found_backprop_input = false;
- bool found_backprop_filter = false;
- Node* backprop_filter_node = nullptr;
- Node* backprop_input_node = nullptr;
-
- for (const Edge* e : bias_add_grad_inp->out_edges()) {
- Node* third_input = nullptr;
- if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
- e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
- // Third input (index 2) of BackpropInput
- TF_CHECK_OK(e->dst()->input_node(2, &third_input));
- // Third input (index 2) of BackpropInput must be same as the input
- // of BiasAddGrad.
- if (third_input == bias_add_grad_inp) {
- found_backprop_input = true;
- backprop_input_node = e->dst();
- }
- }
-
- if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
- e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
- // Third input (index 2) of BackpropFilter
- TF_CHECK_OK(e->dst()->input_node(2, &third_input));
- // Third input (index 2) of BackpropFilter must be same as the input
- // of BiasAddGrad.
- if (third_input == bias_add_grad_inp) {
- found_backprop_filter = true;
- backprop_filter_node = e->dst();
- }
- }
-
- // If we found both the nodes, then we can stop the search.
- if (found_backprop_input && found_backprop_filter) {
- break;
- }
- }
-
- // If BackpropFilter node is not found, then this is not
- // Conv2DWithBias context. For 2nd graph in the example above, only
- // BackpropFilter would be present.
- if (!found_backprop_filter) {
- return false;
- }
-
- // Otherwise, we found the nodes.
- CHECK_NOTNULL(backprop_filter_node);
- if (found_backprop_input) {
- CHECK_NOTNULL(backprop_input_node);
- }
-
- // Now that we confirmed that this is Conv2DWithBias context, we need to
- // get access to the forward node (Conv2DWithBias). 2nd input of
- // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st
- // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter
- // (This comes from definition of gradient computation for Conv2D).
- if (found_backprop_input) {
- // Graph A in the example.
- Node* second_inp_of_input = nullptr;
- Node* first_inp_of_filter = nullptr;
- TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
- TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
- CHECK_NOTNULL(second_inp_of_input);
- CHECK_NOTNULL(first_inp_of_filter);
-
- // Now we need to find out Conv2DWithBias node from these input nodes.
- // Conv2DWithBias node is the node that accepts both the nodes
- // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots.
- for (const Edge* fe : first_inp_of_filter->out_edges()) {
- if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- fe->dst_input() == 0) {
- for (const Edge* ie : second_inp_of_input->out_edges()) {
- if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- ie->dst_input() == 1 && fe->dst() == ie->dst()) {
- VLOG(1) << "MklLayoutRewritePass: found "
- << fe->dst()->DebugString()
- << " as the forward node for matching context, backward"
- << " node is: " << n->DebugString();
- *fwd_node = fe->dst();
- return true;
- }
- }
- }
- }
- } else {
- // We did not find BackpropInput, so we work with BackpropFilter only.
- // Graph B in the example.
- Node* first_inp_of_filter = nullptr;
- TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
- CHECK_NOTNULL(first_inp_of_filter);
-
- // Now we need to find the Conv2DWithBias node from the first input of
- // BackpropFilter. Conv2DWithBias is the node that accepts
- // first_inp_of_filter in its 1st input slot.
- for (const Edge* fe : first_inp_of_filter->out_edges()) {
- if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- fe->dst_input() == 0) {
- VLOG(1) << "MklLayoutRewritePass: found " << fe->dst()->DebugString()
- << " as the forward node for matching context, backward"
- << " node is: " << n->DebugString();
- *fwd_node = fe->dst();
- return true;
- }
- }
- }
-
- return false;
- }
-
- // Is the BiasAddGrad node 'n' associated with the MatMul node
- // specified in contextinfo 'ci'? The function does not update fwd_node.
- //
- // @return - true (if BiasAddGrad is associated with MatMul);
- // false otherwise.
- static bool IsBiasAddGradInMatMulContext(const Node* n, const Node** fwd_node,
- void* ci) {
- return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
- }
-
- // Rewrite rule that uses context-information for matching,
- // used in scenario 2.
- //
- // @input - Node 'n' for which to search for matching context
- // @input - The context 'c' under which to rewrite
- // @return - true if we can rewrite node under context 'c';
- // false otherwise.
- static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
-
- // Helper function that searches the matching contextinfo for the node.
- //
- // @input n - Node (gradient op) whose contextinfo is to be searched,
- // fwd_node - pointer to node from the forward pass that this node
- // belongs to. fwd_node cannot be NULL.
- // @return Matching contextinfo in case a match is found; null otherwise.
- // Also updates *fwd_node with pointer to forward node that this
- // context matches.
- static const ContextInfo* SearchMatchingContext(const Node* n,
- const Node** fwd_node);
-
- // Rewrites input node to a new node specified by its matching rewrite info.
- //
- // Method first searches matching rewrite info for input node and then
- // uses that info to rewrite.
- //
- // Input node may be deleted in case of rewrite. Attempt to use the node
- // after the call can result in undefined behaviors.
- //
- // @input g - input graph, n - Node to be rewritten,
- // ri - matching rewriteinfo
- // @return Status::OK(), if the input node is rewritten;
- // Returns appropriate Status error code otherwise.
- // Graph is updated in case the input node is rewritten.
- // Otherwise, it is not updated.
- Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
-
- // Get nodes that will feed a list of TF tensors to the new
- // node that we are constructing.
- //
- // @input g - input graph,
- // @input inputs - inputs to old node that we are using for constructing
- // new inputs,
- // @input input_idx - the index in the 'inputs' vector pointing to the
- // current input that we have processed so far
- // @output input_idx - index will be incremented by the number of nodes
- // from 'inputs' that are processed
- // @input list_length - The expected length of list of TF tensors
- // @output output_nodes - the list of new nodes creating TF tensors
- //
- // @return None
- void GetNodesProducingTFTensorList(
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes);
-
- // Get nodes that will feed a list of Mkl tensors to the new
- // node that we are constructing.
- //
- // @input g - input graph,
- // @input orig_node - Original node that we are rewriting
- // @input inputs - inputs to old node that we are using for constructing
- // new inputs,
- // @input input_idx - the index in the 'inputs' vector pointing to the
- // current input that we have processed so far
- // @output input_idx - index will be incremented by the number of nodes
- // from 'inputs' that are processed
- // @input list_length - The expected length of list of Mkl tensors
- // @output output_nodes - the list of new nodes creating Mkl tensors
- //
- // @return None
- void GetNodesProducingMklTensorList(
- std::unique_ptr<Graph>* g, Node* orig_node,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes);
-
- // Get a node that will feed an Mkl tensor to the new
- // node that we are constructing. The output node could be (1) 'n'
- // if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
- // if 'n' is not an Mkl layer.
- //
- // @input g - input graph,
- // @input orig_node - Original node that we are rewriting,
- // @input n - Node based on which we are creating Mkl node,
- // @input n_output_slot - the output slot of node 'n'
- // which is feeding to the node that we are constructing
- // @output mkl_node - the new node that will feed Mkl tensor
- // @output mkl_node_output_slot - the slot number of mkl_node that
- // will feed the tensor
- // @return None
- void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
- Node* n, int n_output_slot, Node** mkl_node,
- int* mkl_node_output_slot);
-
- // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
- // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
- // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
- // producing workspace edges if 'are_workspace_tensors_available' is true.
- // Otherwise, 'workspace_tensors' is empty vector.
- //
- // For details, refer to 'Ordering of inputs after rewriting' section in the
- // documentation above.
- //
- // Returns the number of input slots that have been set up on the new
- // node.
- int SetUpContiguousInputs(
- std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
- NodeBuilder* nb, Node* old_node,
- std::vector<NodeBuilder::NodeOut>* workspace_tensors,
- bool are_workspace_tensors_available);
-
- // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
- // in graph 'g'. Original node is input in 'orig_node'.
- //
- // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
- // section in the documentation above.
- //
- // Returns Status::OK() if setting up inputs is successful, otherwise
- // returns appropriate status code.
- Status SetUpInputs(std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- NodeBuilder* nb, Node* orig_node);
-
- // Add a workspace edge on the input or output side of Node 'orig_node' by
- // using NodeBuilder 'nb' for the new node provided. If 'orig_node' does not
- // dictate adding a workspace edge then do not add it. If workspace
- // Tensorflow and Mkl tensors need to be added, they are returned in
- // 'ws_tensors', and '*are_ws_tensors_added' is set to true.
- void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
- NodeBuilder* nb,
- std::vector<NodeBuilder::NodeOut>* ws_tensors,
- bool* are_ws_tensors_added);
-
- // Functions specific to operators to copy attributes.
- // We need operator-specific functions to copy attributes because the
- // framework does not provide any generic function for it.
- // NOTE: names are alphabetically sorted.
- static void CopyAttrsAddN(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
-
- // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
- // using the original node 'orig_node' for device placement and frame
- // information, and return it in '*out'.
- // TODO(nhasabni) We should move this to mkl_util.h
- void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
- Node* orig_node);
- void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
- Node* orig_node);
-};
-
-MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
-MklLayoutRewritePass::ContextInfo
- MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
-MklLayoutRewritePass::ContextInfo
- MklLayoutRewritePass::biasaddgrad_matmul_context_;
-std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
-
-// We register Mkl rewrite pass for phase 1 in post partitioning group.
-// We register it here so that we get a complete picture of all users of Mkl
-// nodes. Do not change the ordering of the Mkl passes.
-const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
- OptimizationPassRegistry::POST_PARTITIONING;
-#ifdef ENABLE_MKL
-REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
-#endif // ENABLE_MKL
-
-//////////////////////////////////////////////////////////////////////////
-// Helper functions for creating new node
-//////////////////////////////////////////////////////////////////////////
-
-static void FillInputs(const Node* n,
- gtl::InlinedVector<Node*, 4>* control_edges,
- gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
- control_edges->clear();
- for (const Edge* e : n->in_edges()) {
- if (e->IsControlEdge()) {
- control_edges->push_back(e->src());
- } else {
- (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
- }
- }
- std::sort(control_edges->begin(), control_edges->end());
- if (n->op_def().is_commutative()) {
- // For commutative inputs, we sort the input by the input Node*
- // to get a canonical ordering (so that add(a,b) and add(b, a) will
- // hash to the same value if is_commutative is true for 'add').
- std::sort(in->begin(), in->end());
- }
-}
-
-void MklLayoutRewritePass::GetNodesProducingTFTensorList(
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
- int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
- CHECK_LT(*input_idx, inputs.size());
- CHECK_GT(list_length, 0);
- CHECK_NOTNULL(output_nodes);
- output_nodes->reserve(list_length);
-
- while (list_length != 0) {
- CHECK_GT(list_length, 0);
- CHECK_LT(*input_idx, inputs.size());
- Node* n = inputs[*input_idx].first;
- int slot = inputs[*input_idx].second;
- // If input node 'n' is just producing a single tensor at
- // output slot 'slot' then we just add that single node.
- output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
- (*input_idx)++;
- list_length--;
- }
-}
-
-// TODO(nhasabni) We should move this to mkl_util.h.
-void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
- Node** out, Node* orig_node) {
- // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
- // dummy Mkl tensor. 8 = 2*size_t.
- const DataType dt = DataTypeToEnum<uint8>::v();
- TensorProto proto;
- proto.set_dtype(dt);
- uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
- proto.set_tensor_content(string(reinterpret_cast<const char*>(zero), 8));
- TensorShape dummy_shape({8});
- dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // the same device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
- CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
-
- // If number of inputs to the original node is > 0, then we add
- // control dependency between 1st input (index 0) of the original node and
- // the dummy Mkl node. This is needed because control-flow ops such as Enter,
- // Merge, etc, require frame_name of the dummy Mkl node to be same as the
- // rewritten node. Adding control edge between 1st input of the original node
- // and the dummy Mkl node ensures that the dummy node is in the same frame
- // as the original node. Choosing 1st input is not necessary - any input of
- // the original node is fine because all the inputs of a node are always in
- // the same frame.
- if (orig_node->num_inputs() > 0) {
- Node* orig_input0 = nullptr;
- TF_CHECK_OK(
- orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
- }
-
- (*out)->set_assigned_device_name(orig_node->assigned_device_name());
-}
-
-void MklLayoutRewritePass::GetNodesProducingMklTensorList(
- std::unique_ptr<Graph>* g, Node* orig_node,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
- int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
- CHECK_LT(*input_idx, inputs.size());
- CHECK_GT(list_length, 0);
- CHECK_NOTNULL(output_nodes);
- output_nodes->reserve(list_length);
-
- while (list_length != 0) {
- CHECK_GT(list_length, 0);
- CHECK_LT(*input_idx, inputs.size());
- Node* n = inputs[*input_idx].first;
- int slot = inputs[*input_idx].second;
- // If 'n' is producing a single tensor, then create a single Mkl tensor
- // node.
- Node* mkl_node = nullptr;
- int mkl_node_output_slot = 0;
- GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
- &mkl_node_output_slot);
- output_nodes->push_back(
- NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
- (*input_idx)++;
- list_length--;
- }
-}
-
-// Get an input node that will feed an Mkl tensor to the new
-// node that we are constructing. An input node could be (1) 'n'
-// if it is an Mkl layer, or (2) a dummy node producing a dummy Mkl tensor
-// if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(
- std::unique_ptr<Graph>* g, Node* orig_node, Node* n, int n_output_slot,
- Node** mkl_node, int* mkl_node_output_slot) {
- CHECK_NOTNULL(n);
- CHECK_NOTNULL(mkl_node);
- CHECK_NOTNULL(mkl_node_output_slot);
-
- // If this is an MKL op, then it will create extra output for MKL layout.
- DataType T;
- if (GetNodeAttr(n->def(), "T", &T).ok() &&
- mkl_op_registry::IsMklOp(n->type_string(), T)) {
- // If this is an MKL op, then it will generate an edge that will receive
- // Mkl tensor from a node.
- // output slot number for Mkl tensor would be N+slot number of TensorFlow
- // tensor, where N is total number of TensorFlow tensors.
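- // For example, a rewritten node with 2 TensorFlow outputs has 4 outputs
- // in total, and the Mkl tensor for TensorFlow output slot 1 is at slot 3.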
- *mkl_node = n;
- *mkl_node_output_slot =
- GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
- } else {
- // If we have not visited the node and rewritten it, then we need
- // to create a dummy node that will feed a dummy Mkl tensor to this node.
- // DummyMklTensor node has no input and generates only 1 output
- // (dummy Mkl tensor) as output slot number 0.
- GetDummyMklTensorNode(g, mkl_node, orig_node);
- CHECK_NOTNULL(*mkl_node);
- *mkl_node_output_slot = 0;
- }
-}
-
-int MklLayoutRewritePass::SetUpContiguousInputs(
- std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
- NodeBuilder* nb, Node* old_node,
- std::vector<NodeBuilder::NodeOut>* workspace_tensors,
- bool are_workspace_tensors_available) {
- CHECK_NOTNULL(workspace_tensors);
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
-
- // TODO(nhasabni): Temporary solution to connect filter input of
- // BackpropInput with the converted filter from Conv2D.
- bool do_connect_conv2d_backprop_input_filter = false;
- Node* conv2d_node = nullptr;
- // Filter node is 2nd input (slot index 1) of Conv2D.
- int kConv2DFilterInputSlotIdx = 1;
- int kConv2DBackpropInputFilterInputSlotIdx = 1;
- int kConv2DFilterOutputSlotIdx = 1;
- if (old_node->type_string() == csinfo_.conv2d_grad_input) {
- // We need to find Conv2D node from Conv2DBackpropInput.
- // For that let's first find filter node that is 2nd input (slot 1)
- // of BackpropInput.
- Node* filter_node = nullptr;
- TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
- &filter_node));
- CHECK_NOTNULL(filter_node);
-
- // Now check which nodes receive from filter_node. Filter feeds as
- // 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
- for (const Edge* e : filter_node->out_edges()) {
- if (e->dst()->type_string() == csinfo_.mkl_conv2d &&
- e->dst_input() == kConv2DFilterInputSlotIdx
- /* filter is 2nd input of Conv2D and _MklConv2D. */) {
- if (conv2d_node != nullptr) {
- VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
- << " feeding multiple Conv2D nodes: "
- << filter_node->DebugString();
- // We will not connect filter input of Conv2DBackpropInput
- // to be safe here.
- do_connect_conv2d_backprop_input_filter = false;
- break;
- } else {
- conv2d_node = e->dst();
- do_connect_conv2d_backprop_input_filter = true;
- }
- }
- }
- }
-
- // Number of input slots to original op
- // Input slots are represented by .Input() calls in REGISTER_OP.
- int old_node_input_slots = old_node->op_def().input_arg_size();
- // Actual number of inputs can be greater than or equal to number
- // of Input slots because inputs of type list could be unfolded.
- CHECK_GE(old_node_inputs.size(), old_node_input_slots);
- int nn_slot_idx = 0; // slot index for inputs of new node
-
- // Let's copy all inputs (TF tensors) of original node to new node.
- int iidx = 0;
- for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
- // An input slot could be a single tensor or a list. We need
- // to handle this case accordingly.
- CHECK_LT(iidx, old_node_inputs.size());
- const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
- if (ArgIsList(arg)) {
- std::vector<NodeBuilder::NodeOut> new_node_inputs;
- int N = GetTensorListLength(arg, old_node);
- GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
- &new_node_inputs);
- nb->Input(new_node_inputs);
- nn_slot_idx++;
- } else {
- // Special case for connecting filter input of Conv2DBackpropInput
- if (do_connect_conv2d_backprop_input_filter &&
- iidx == kConv2DBackpropInputFilterInputSlotIdx) {
- nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
- } else {
- nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
- }
- iidx++;
- nn_slot_idx++;
- }
- }
-
- // If workspace tensors are available for this op and we are using
- // contiguous ordering then we need to add Tensorflow tensor for
- // workspace here because Tensorflow tensor for workspace is the
- // last tensor in the list of Tensorflow tensors.
- if (are_workspace_tensors_available) {
- CHECK_EQ(workspace_tensors->size(), 2);
- // Tensorflow tensor
- nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
- nn_slot_idx++;
- }
-
- // Let's now setup all Mkl inputs to new node.
- // Number of Mkl inputs must be same as number of TF inputs.
- iidx = 0;
- for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
- // An input slot could be a single tensor or a list. We need
- // to handle this case accordingly.
- CHECK_LT(iidx, old_node_inputs.size());
- const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
- if (ArgIsList(arg)) {
- std::vector<NodeBuilder::NodeOut> new_node_inputs;
- int N = GetTensorListLength(arg, old_node);
- GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, N,
- &new_node_inputs);
- nb->Input(new_node_inputs);
- nn_slot_idx++;
- } else {
- Node* mkl_node = nullptr;
- int mkl_node_output_slot = 0;
- // Special case for connecting filter input of Conv2DBackpropInput
- if (do_connect_conv2d_backprop_input_filter &&
- iidx == kConv2DBackpropInputFilterInputSlotIdx) {
- GetNodeProducingMklTensor(g, old_node, conv2d_node,
- kConv2DFilterOutputSlotIdx, &mkl_node,
- &mkl_node_output_slot);
- } else {
- GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
- old_node_inputs[iidx].second, &mkl_node,
- &mkl_node_output_slot);
- }
- nb->Input(mkl_node, mkl_node_output_slot);
- iidx++;
- nn_slot_idx++;
- }
- }
-
- // If workspace tensors are available for this op and we are using
- // contiguous ordering then we need to add Mkl tensor for
- // workspace here because Mkl tensor for workspace is the
- // last tensor in the list of Mkl tensors.
- if (are_workspace_tensors_available) {
- CHECK_EQ(workspace_tensors->size(), 2);
- // Mkl tensor
- nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
- nn_slot_idx++;
- }
-
- return nn_slot_idx;
-}
-
-Status MklLayoutRewritePass::SetUpInputs(
- std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
- NodeBuilder* nb, Node* old_node) {
- // Let's check if we need to add workspace tensors for this node.
- // We add workspace edge only for MaxPool, LRN and BatchNorm.
- std::vector<NodeBuilder::NodeOut> workspace_tensors;
- bool are_workspace_tensors_available = false;
- AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
- &are_workspace_tensors_available);
-
- int new_node_input_slots = 0;
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- // TODO(nhasabni): implement this function just for the sake of completeness.
- // We do not use interleaved ordering right now.
- return Status(
- error::Code::UNIMPLEMENTED,
- "Interleaved ordering of tensors is currently not supported.");
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- new_node_input_slots = SetUpContiguousInputs(
- g, old_node_inputs, nb, old_node, &workspace_tensors,
- are_workspace_tensors_available);
- }
-
- // Sanity check
- int old_node_input_slots = old_node->op_def().input_arg_size();
- if (!are_workspace_tensors_available) {
- // If we are not adding workspace tensors for this op, then the total
- // number of input slots to the new node _must_ be 2 times the number
- // of input slots to the original node: N original Tensorflow tensors and
- // N for Mkl tensors corresponding to each Tensorflow tensors.
- CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
- } else {
- // If we are adding workspace tensors for this op, then the total
- // number of input slots to the new node _must_ be 2 times the number
- // of input slots to the original node plus 2: N original Tensorflow
- // tensors, N Mkl tensors corresponding to each Tensorflow tensor, plus
- // the workspace Tensorflow tensor and the workspace Mkl tensor.
- CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
- }
-
- return Status::OK();
-}
-
-//////////////////////////////////////////////////////////////////////////
-// Helper functions related to workspace pass
-//////////////////////////////////////////////////////////////////////////
-
-// TODO(nhasabni) We should move this to mkl_util.h.
-void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
- std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
- // We use a tensor of shape {1} and value 0 to represent
- // dummy float tensor. We need this as a dummy workspace tensor.
- // Workspace tensor has type float.
- const DataType dt = DataTypeToEnum<float>::v();
- TensorProto proto;
- proto.set_dtype(dt);
- float zero[1] = {0};
- proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4));
- TensorShape dummy_shape({1});
- dummy_shape.AsProto(proto.mutable_tensor_shape());
- TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // the same device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
- CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
-
- // If number of inputs to the original node is > 0, then we add
- // control dependency between 1st input (index 0) of the original node and
- // the dummy Mkl node. This is needed because control-flow ops such as Enter,
- // Merge, etc, require frame_name of the dummy Mkl node to be same as the
- // rewritten node. Adding control edge between 1st input of the original node
- // and the dummy Mkl node ensures that the dummy node is in the same frame
- // as the original node. Choosing 1st input is not necessary - any input of
- // the original node is fine because all the inputs of a node are always in
- // the same frame.
- if (orig_node->num_inputs() > 0) {
- Node* orig_input0 = nullptr;
- TF_CHECK_OK(
- orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
- }
-
- (*out)->set_assigned_device_name(orig_node->assigned_device_name());
-}
-
-void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
- std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
- std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
- bool workspace_edge_added = false; // Default initializer
- CHECK_NOTNULL(are_ws_tensors_added);
- *are_ws_tensors_added = false; // Default initializer
-
- DataType T;
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- for (auto ws : wsinfo_) {
- if (orig_node->type_string() == ws.fwd_op &&
- mkl_op_registry::IsMklOp(
- mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
- // If this op is a fwd op, then we need to check if there is an
- // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
- // an edge, then we just add an attribute on this node to set
- // workspace_enabled to true. We don't add an actual workspace edge
- // in this node. The actual workspace edge gets added in the backward
- // op for this node.
- for (const Edge* e : orig_node->out_edges()) {
- if (e->src_output() == ws.fwd_slot &&
- e->dst()->type_string() == ws.bwd_op &&
- e->dst_input() == ws.bwd_slot) {
- nb->Attr("workspace_enabled", true);
- VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
- << orig_node->type_string();
- workspace_edge_added = true;
- // We found the edge that we were looking for, so break.
- break;
- }
- }
-
- if (!workspace_edge_added) {
- // If we are here, then we did not find backward operator for this
- // node.
- nb->Attr("workspace_enabled", false);
- }
- } else if (orig_node->type_string() == ws.bwd_op &&
- mkl_op_registry::IsMklOp(
- mkl_op_registry::GetMklOpName(orig_node->type_string()),
- T)) {
- // If this op is a bwd op, then we need to add a workspace edge and
- // its Mkl tensor edge between its corresponding fwd op and this
- // op. The corresponding fwd op is specified in the 'fwd_op' field of
- // the workspace info. fwd_slot and bwd_slot in the workspace info
- // specify the slots through which the forward and backward ops connect.
- // Once all these criteria match, we add a workspace edge between
- // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
- // determined by interleaved/contiguous ordering. Function
- // DataIndexToMetaDataIndex tells us the location of Mkl tensor
- // from the location of the Tensorflow tensor.
- for (const Edge* e : orig_node->in_edges()) {
- if (e->src_output() == ws.fwd_slot &&
- // We would have rewritten the forward op, so we need to use
- // GetMklOpName call to get its Mkl name.
- e->src()->type_string() ==
- mkl_op_registry::GetMklOpName(ws.fwd_op) &&
- e->dst_input() == ws.bwd_slot) {
- nb->Attr("workspace_enabled", true);
- CHECK_NOTNULL(ws_tensors);
- // Add workspace edge between fwd op and bwd op.
- ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
- // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
- ws_tensors->push_back(NodeBuilder::NodeOut(
- e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
- e->src()->num_outputs())));
- *are_ws_tensors_added = true;
-          // In terms of input ordering, we add these inputs here because the
-          // workspace edge (and its Mkl tensor) is the last input of both the
-          // fwd op and the bwd op. All inputs before the workspace tensor
-          // have already been added by the SetUpInputs function.
- VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
- << orig_node->type_string();
- workspace_edge_added = true;
- // We found the edge that we were looking for, so break.
- break;
- }
- }
-
-      // If we are here, it means we did not find a fwd op that feeds this
-      // bwd op. In that case, we need to generate dummy tensors for the
-      // workspace input and for its Mkl tensor, and set workspace_enabled
-      // to false.
- if (!workspace_edge_added) {
- nb->Attr("workspace_enabled", false);
- Node* dmt_ws = nullptr; // Dummy tensor for workspace
- Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
- GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
- GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
- CHECK_NOTNULL(dmt_ws);
- CHECK_NOTNULL(dmt_mkl_ws);
- CHECK_NOTNULL(ws_tensors);
- // We add dummy tensor as workspace tensor.
- ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
- // We add dummy tensor as Mkl tensor for workspace tensor.
- ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
- *are_ws_tensors_added = true;
- VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
- << orig_node->type_string();
- }
- } else {
- // If this node does not match any workspace info, then we do not
- // do anything special for workspace propagation for it.
- }
- }
-}
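
For reference, a minimal sketch of the data-slot to metadata-slot mapping that the comments above attribute to DataIndexToMetaDataIndex (the real helper lives in mkl_util.h; the implementation below is assumed from that description, not quoted):

int MetaDataSlotSketch(int slot, int total, bool interleaved) {
  // Sketch only. 'slot' is the output slot of a Tensorflow data tensor and
  // 'total' is the node's total number of outputs (TF tensors + Mkl tensors).
  // Interleaved: T0, M0, T1, M1, ... - metadata follows its data tensor.
  // Contiguous:  T0, T1, ..., M0, M1, ... - metadata sits in the second half.
  return interleaved ? slot + 1 : slot + total / 2;
}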
-
-//////////////////////////////////////////////////////////////////////////
-// Op-specific functions to copy attributes from old node to new node
-//////////////////////////////////////////////////////////////////////////
-
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- string data_format;
- string padding;
- std::vector<int32> strides;
- bool use_cudnn_on_gpu;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("strides", strides);
- nb->Attr("padding", padding);
- nb->Attr("data_format", data_format);
- nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
-}
-
-void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- int N;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("N", N);
-}
-
-void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- string data_format;
- std::vector<int32> strides;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("strides", strides);
- nb->Attr("data_format", data_format);
-}
-
-void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- // Add attributes to new node.
- nb->Attr("T", T);
-}
-
-void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- int depth_radius;
- float bias;
- float alpha;
- float beta;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("depth_radius", depth_radius);
- nb->Attr("bias", bias);
- nb->Attr("alpha", alpha);
- nb->Attr("beta", beta);
-}
-
-void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- string data_format;
- string padding;
- std::vector<int32> ksize, strides;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("ksize", ksize);
- nb->Attr("strides", strides);
- nb->Attr("padding", padding);
- nb->Attr("data_format", data_format);
-}
-
-void MklLayoutRewritePass::CopyAttrsDataType(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
-
- // Add attributes to new node.
- nb->Attr("T", T);
-}
-
-void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- DataType Tshape;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("Tshape", Tshape);
-}
-
-void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- string data_format;
- int num_split;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("num_split", num_split);
- nb->Attr("data_format", data_format);
-}
-
-void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- int N;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("N", N);
-}
-
-void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- int N;
- DataType tidx;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("N", N);
- nb->Attr("Tidx", tidx);
-}
-
-void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- float epsilon;
- string data_format;
- bool is_training;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("epsilon", epsilon);
- nb->Attr("data_format", data_format);
- nb->Attr("is_training", is_training);
-}
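
These CopyAttrs helpers are invoked through the copy_attrs callback of the matching rewrite entry when the replacement node is assembled (see RewriteNode below). A minimal sketch of that pattern, assuming a hypothetical "_MklRelu" op with a single "T" attribute and access to the pass's helpers:

Status BuildMklReluSketch(Graph* graph, const Node* orig, Node** new_node) {
  NodeBuilder nb(orig->name(), "_MklRelu");
  // Data inputs and their Mkl metadata inputs would be added here via
  // nb.Input(), exactly as SetUpInputs does in the pass.
  CopyAttrsDataType(orig, &nb);                      // copies "T"
  nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);  // mark as an Mkl kernel
  return nb.Finalize(graph, new_node);
}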
-
-//////////////////////////////////////////////////////////////////////////
-// Helper functions related to node merge pass
-//////////////////////////////////////////////////////////////////////////
-
-Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
- // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
- // once we support BiasAddGrad as Mkl layer.
-
- // Search for all matching mergeinfo.
- // We allow more than one match for extensibility.
- std::vector<const MergeInfo*> matching_mi;
- for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
- if (a->type_string() == mi->succ) {
- matching_mi.push_back(&*mi);
- }
- }
-
- for (const MergeInfo* mi : matching_mi) {
- const int N_in = a->num_inputs();
- if (mi->op >= N_in) {
- continue;
- }
-
-    // Get the control edges and inputs of the node.
- gtl::InlinedVector<Node*, 4> a_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
- FillInputs(a, &a_control_edges, &a_in);
-
-    // Get the node that feeds input slot mi->op of 'a'.
-    Node* b = a_in[mi->op].first;
- if (b == nullptr || (b->type_string() != mi->pred)) {
- // NOTE: Should the first check be assert?
- continue;
- }
-
- const int B_in = b->num_inputs();
- gtl::InlinedVector<Node*, 4> b_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
- FillInputs(b, &b_control_edges, &b_in);
-
- // Shouldn't merge if a and b have different control edges.
- if (a_control_edges != b_control_edges) {
- continue;
- } else {
- // We found a match.
- return b;
- }
- }
-
- return nullptr;
-}
-
-Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
- Node* pred) {
- CHECK_NOTNULL(succ);
- CHECK_NOTNULL(pred);
-
- if (succ->type_string() == csinfo_.bias_add &&
- pred->type_string() == csinfo_.mkl_conv2d) {
- // 1. Get all attributes from input nodes.
- DataType T_pred, T_succ;
- string padding;
- std::vector<int32> strides;
- string data_format_pred, data_format_succ;
-    bool use_cudnn_on_gpu;
- TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
- TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
- TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
- TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
-    TF_CHECK_OK(
-        GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
-    // We check to ensure that the data formats of both succ and pred are the
-    // same. We expect them to be the same, so we could enforce this with an
-    // assert, but an assert can be too strict, so we enforce it as a check
-    // instead. If the check fails, then we do not merge the two nodes.
-    // We do the same check for devices.
- if (data_format_pred != data_format_succ || T_pred != T_succ ||
- pred->assigned_device_name() != succ->assigned_device_name() ||
- pred->def().device() != succ->def().device()) {
- return Status(error::Code::INVALID_ARGUMENT,
- "data_format or T attribute or devices of Conv2D and "
- "BiasAdd do not match. Will skip node merge optimization");
- }
-
- const int succ_num = succ->num_inputs();
- gtl::InlinedVector<Node*, 4> succ_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
- FillInputs(succ, &succ_control_edges, &succ_in);
-
- const int pred_num = pred->num_inputs();
- gtl::InlinedVector<Node*, 4> pred_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
- FillInputs(pred, &pred_control_edges, &pred_in);
-
-    // We need to ensure that there is only 1 edge between Conv2D and BiasAdd.
-    // Otherwise, merging is semantically incorrect.
- if (pred->out_edges().size() != 1) {
- return Status(error::Code::INVALID_ARGUMENT,
-                    "Conv2D has multiple outputs. "
-                    "Will skip node merge optimization");
- }
-
- for (const Edge* e : pred->out_edges()) {
- if (e->dst() != succ) {
- return Status(error::Code::INVALID_ARGUMENT,
-                      "Conv2D does not feed to BiasAdd. "
-                      "Will skip node merge optimization");
- }
- }
-
-    // 2. Get inputs from both the nodes.
-    // Find the 2 inputs from the Conv2D and the bias from the BiasAdd.
-    // Get operands 0 and 1 of Conv2D and their Mkl tensors.
- CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
- // Get operand 1 of add_bias
- // BiasAdd must have 2 inputs: Conv, bias
- CHECK_EQ(succ->in_edges().size(), 2);
- Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
- int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0.
- GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node
- // as BiasAdd does not have Mkl tensor as input.
- CHECK_NOTNULL(oper3_mkl);
-
-    // Build the new node. We use the name of the BiasAdd node as the name of
-    // the new node, but change the op name.
- NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias);
- if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
- nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
- // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
- // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
- // we follow contiguous ordering.
- nb.Input(pred_in[1].first, pred_in[1].second); // Mkl for In1
- nb.Input(pred_in[2].first, pred_in[2].second); // In2 of Conv2D
- nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2
- nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
- nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd
- } else {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
- // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
- // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
- // we follow contiguous ordering.
- nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
- nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
- nb.Input(pred_in[2].first, pred_in[2].second); // Mkl for In1 of Conv2D
- nb.Input(pred_in[3].first, pred_in[3].second); // Mkl for In2 of Conv2D
- nb.Input(oper3_mkl, oper3_mkl_slot); // Mkl for In2 of BiasAdd
- }
-
- // Copy attributes from Conv2D to Conv2DWithBias.
- CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
-
- // Copy the device assigned to old node to new node.
- nb.Device(succ->def().device());
-
- // Create node.
- Node* new_node;
- TF_CHECK_OK(nb.Finalize(&**g, &new_node));
- CHECK_NOTNULL(new_node);
-
- // Set the Mkl layer label for this op.
- new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
-
-    // Incoming data edges from the 'pred' and 'succ' nodes to 'new_node' were
-    // already set up when we built the node above. We handle control edges
-    // now.
- for (const Edge* e : pred->in_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
- }
- }
- for (const Edge* e : succ->in_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
- }
- }
-
- // Incoming edges are fixed, we will fix the outgoing edges now.
- // First, we will fix outgoing control edges from 'pred' node.
- // We don't need to handle outgoing data edges from 'pred' node
- // because pred has only 1 output going to succ node (we enforced
- // this check for merge already).
- for (const Edge* e : pred->out_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
- }
- }
-
- // Second, we will fix outgoing control and data edges from 'succ' node.
- for (const Edge* e : succ->out_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
- } else {
- CHECK_NOTNULL(
- (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input()));
- }
- }
-
- // Copy device assigned to old node to new node.
- // It's ok to use pred or succ as we have enforced a check that
- // both have same device assigned.
- new_node->set_assigned_device_name(pred->assigned_device_name());
-
- VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
- << ", and node: " << succ->DebugString()
- << ", into node:" << new_node->DebugString();
-
- (*g)->RemoveNode(succ);
- (*g)->RemoveNode(pred);
-
- return Status::OK();
- }
-
- return Status(error::Code::UNIMPLEMENTED,
- "Unimplemented case for node merge optimization.");
-}
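
To make the contiguous branch above concrete, here is a worked example of the input slots of the merged node (tensor names are hypothetical):

// Inputs of the merged _MklConv2DWithBias node, contiguous ordering:
//   slot 0: conv_in0    (In1 of Conv2D)
//   slot 1: conv_in1    (In2 of Conv2D)
//   slot 2: bias        (In2 of BiasAdd)
//   slot 3: conv_in0_m  (Mkl tensor for In1 of Conv2D)
//   slot 4: conv_in1_m  (Mkl tensor for In2 of Conv2D)
//   slot 5: bias_m      (dummy Mkl tensor generated for the bias)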
-
-//////////////////////////////////////////////////////////////////////////
-// Helper functions for node rewrite
-//////////////////////////////////////////////////////////////////////////
-
-Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
- Node* orig_node,
- const RewriteInfo* ri) {
- CHECK_NOTNULL(ri);
- CHECK_NOTNULL(orig_node);
-
- VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
-
- // Check if this is scenario 2 (context-based rewrite).
- // Get the matching ContextInfo if it is.
- const Node* fwd_node = nullptr;
- const ContextInfo* ci = nullptr;
- bool is_context_based_rewrite = false;
- if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
- is_context_based_rewrite = true;
-
- // Sanity checks for context-based rewrite (if any)
- if (orig_node->type_string() == csinfo_.bias_add_grad &&
- ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
- CHECK_NOTNULL(fwd_node);
- DataType orig_T, ctx_T;
- string orig_data_format, ctx_data_format;
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
- TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
- TF_CHECK_OK(
- GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));
-
- if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
- orig_node->assigned_device_name() !=
- fwd_node->assigned_device_name() ||
- orig_node->def().device() != fwd_node->def().device()) {
- return Status(
- error::Code::INVALID_ARGUMENT,
- "data_format or T attribute or devices of BiasAddGrad and "
- "Conv2D do not match. Will skip node rewrite optimization");
- }
- } else if (orig_node->type_string() == csinfo_.bias_add_grad &&
- ri->new_name == csinfo_.matmul) {
-      // When BiasAddGrad has MatMul in its context, we do not do any rewrite
-      // and leave BiasAddGrad as it is. We already check for this condition
-      // when we check the node rewrite rule, so we should not even get here
-      // for the MatMul context. Hence we fail now.
- return Status(
- error::Code::INVALID_ARGUMENT,
- "No rewrite is required for BiasAddGrad for MatMul context.");
- }
- }
-
- // Get all inputs.
- int num_inputs = orig_node->in_edges().size();
-
-  // Control edges are not data inputs, so drop them from the count.
- for (const Edge* e : orig_node->in_edges()) {
- if (e->IsControlEdge()) {
- num_inputs--;
- }
- }
-
- gtl::InlinedVector<Node*, 4> control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
- FillInputs(orig_node, &control_edges, &inputs);
-
- // Build new node. We use same name as original node, but change the op name.
- NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
- // Copy user-specified device assigned to original node to new node.
- nb.Device(orig_node->def().device());
- // Set up new inputs to the rewritten node.
- Status s = SetUpInputs(g, inputs, &nb, orig_node);
- if (s != Status::OK()) {
- return s;
- }
-
- // Copy attributes from original node to new node (for scenario 1).
- // For context-based rewrite, we use context to copy the attributes.
- if (is_context_based_rewrite) {
- if (orig_node->type_string() == csinfo_.bias_add_grad &&
- ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
- CHECK_NOTNULL(fwd_node);
- ri->copy_attrs(fwd_node, &nb);
- } else {
- return Status(error::Code::UNIMPLEMENTED,
- "Unimplemented case for node rewrite optimization.");
- }
- } else {
- ri->copy_attrs(const_cast<const Node*>(orig_node), &nb);
- }
- // Set the Mkl layer label for this op.
- nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
-
- // Finalize graph and get new node.
- Node* new_node = nullptr;
- TF_CHECK_OK(nb.Finalize(&**g, &new_node));
- CHECK_NOTNULL(new_node);
-
-  // Incoming data edges from 'orig_node' to 'new_node' were already set up
-  // when we built the node above. We need to handle control edges now.
- for (const Edge* e : orig_node->in_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
- }
- }
-
-  // Copy outgoing edges from 'orig_node' to 'new_node', since the outputs
-  // also follow the same ordering among Tensorflow tensors and Mkl tensors.
-  // We need to connect the Tensorflow tensors appropriately. Specifically,
-  // the nth output of the original node becomes the 2*nth output of the Mkl
-  // node for the interleaved ordering of tensors, and remains the nth output
-  // for the contiguous ordering. GetTensorDataIndex provides this mapping.
- for (const Edge* e : orig_node->out_edges()) {
- if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
- } else {
- CHECK_NOTNULL((*g)->AddEdge(
- new_node,
- GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
- e->dst(), e->dst_input()));
- }
- }
-
-  // Copy the runtime device assigned to the original node to the new node.
- new_node->set_assigned_device_name(orig_node->assigned_device_name());
-
- // Delete original node and mark new node as rewritten.
- (*g)->RemoveNode(orig_node);
-
- VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
- return Status::OK();
-}
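
A minimal sketch of the output-slot mapping described in the comment above (the real GetTensorDataIndex lives in mkl_util.h; the implementation below is assumed from that description):

int TensorDataSlotSketch(int n, bool interleaved) {
  // Sketch only: map the nth logical output of the original node to the
  // data-tensor slot of the rewritten Mkl node.
  // Interleaved: data tensors occupy even slots (2n), metadata the odd ones.
  // Contiguous:  data tensors keep their original slot n.
  return interleaved ? 2 * n : n;
}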
-
-const MklLayoutRewritePass::ContextInfo*
-MklLayoutRewritePass::SearchMatchingContext(const Node* n,
- const Node** fwd_node) {
- CHECK_NOTNULL(n);
- CHECK_NOTNULL(fwd_node);
- *fwd_node = nullptr;
-
-  // Search for a matching contextinfo based on the node's op type, and call
-  // the context-matching callback of that contextinfo. There could be more
-  // than one matching contextinfo, but whichever matches first is returned.
- for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
- if (n->type_string() == (*ci)->node &&
- (*ci)->context_match_fn(n, fwd_node, *ci)) {
-      VLOG(1) << "Found matching context: " << (*ci)->fwd;
- return *ci;
- }
- }
- return nullptr;
-}
-
-bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
- const ContextInfo* c) {
- const Node* fwd_node = nullptr;
- return SearchMatchingContext(n, &fwd_node) == c;
-}
-
-const MklLayoutRewritePass::RewriteInfo*
-MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
- CHECK_NOTNULL(n);
-
- // First check if node along with its type is supported by MKL layer.
- // We do not want to rewrite an op into Mkl op if types are not supported.
- // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
- // MklRelu if type is INT32.
- DataType T;
- if (!GetNodeAttr(n->def(), "T", &T).ok()) {
- return nullptr;
- }
-
- // BiasAddGrad is not an Mkl layer, so we make an exception for it.
- if (n->type_string() != csinfo_.bias_add_grad) {
- if (!mkl_op_registry::IsMklOp(
- mkl_op_registry::GetMklOpName(n->type_string()), T)) {
- return nullptr;
- }
- }
-
-  // For an elementwise node, we reuse the Eigen implementation and pass the
-  // MKL metadata tensor through so we can avoid conversions. However, if all
-  // incoming edges are in TF format, we don't need all this overhead, so we
-  // replace the elementwise node only if at least one of its parents is an
-  // MKL node.
- //
- // TODO(vrane): Add implementation for element-wise ops that doesn't reuse
- // eigen code to reduce cross-library dependency.
- if (mkl_op_registry::IsMklElementWiseOp(
- mkl_op_registry::GetMklOpName(n->type_string()), T)) {
- bool incoming_mkl_edge = false;
- for (auto parent : n->in_edges()) {
- if (mkl_op_registry::IsMklOp(
- mkl_op_registry::GetMklOpName(parent->src()->type_string()), T)) {
- incoming_mkl_edge = true;
- break;
- } else {
- VLOG(1) << "Non-MKL parent is: " << parent->src()->type_string();
- }
- }
- if (incoming_mkl_edge == false) {
- VLOG(1) << "Skipping replacement of elementwise node which has no MKL "
- "parents.";
- return nullptr;
- }
- }
-
-  // We support 2 types of node rewrites:
-  // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
-  // 2. Rewriting an op into its Mkl op unconditionally.
-  // We return the matching RewriteInfo if either of these conditions is met.
-
- // Find matching RewriteInfo and then check that rewrite rule applies.
- for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
- if (n->type_string().compare(ri->name) == 0 &&
- ri->rewrite_rule(n, ri->context)) {
-      // If we are rewriting BiasAddGrad into BiasAddGrad for the MatMul
-      // context, then no rewrite is needed, so we return nullptr.
- if (n->type_string() == csinfo_.bias_add_grad &&
- ri->context->fwd == csinfo_.matmul &&
- ri->new_name == csinfo_.bias_add_grad) {
- return nullptr;
- }
- return &*ri;
- }
- }
-
- // Else return not found.
- return nullptr;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-// Run function for the pass
-///////////////////////////////////////////////////////////////////////////////
-
-bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
- bool result = false;
- CHECK_NOTNULL(g);
-
- DumpGraph("Before running MklLayoutRewritePass", &**g);
-
- std::vector<Node*> order;
- GetReversePostOrder(**g, &order); // This will give us topological sort.
-
- for (Node* n : order) {
- // If node is not an op or it cannot run on CPU device, then skip.
- if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
- continue;
- }
-
- const RewriteInfo* ri = nullptr;
- Node* predn = nullptr;
-    // We will first check whether the node is to be rewritten.
- if ((ri = CheckForNodeRewrite(n)) != nullptr) {
- string node_name = n->name();
- string op_name = n->type_string();
-
- VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
- << " with op " << op_name << " for rewrite using"
- << " layout optimization.";
-
- if (RewriteNode(g, n, ri) == Status::OK()) {
- VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
- << " with op " << op_name << " for Mkl layout optimization.";
- result = true;
- }
- } else if ((predn = CheckForNodeMerge(n)) != nullptr) {
- // Otherwise, we will check if the node is to be merged.
- string n1_name = n->name();
- string n2_name = predn->name();
-
- VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
- << n2_name << " for merging";
-
- if (MergeNode(g, n, predn) == Status::OK()) {
- VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
- << n2_name;
- result = true;
- }
- }
- }
-
- DumpGraph("After running MklLayoutRewritePass", &**g);
-
- return result;
-}
-
-bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
- return MklLayoutRewritePass().RunPass(g);
-}
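
A minimal usage sketch of this entry point, assuming the caller owns the graph through a std::unique_ptr (mirroring the test harness in mkl_layout_pass_test.cc):

std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
// ... construct or import a graph into *g ...
bool changed = RunMklLayoutRewritePass(&g);  // true if the graph was mutated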
-
-Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
- if (options.graph == nullptr && options.partition_graphs == nullptr) {
- return Status::OK();
- }
-
- auto process_graph = [&](std::unique_ptr<Graph>* g) {
- // Get the ownership of a graph
- std::unique_ptr<Graph>* ng = std::move(g);
- RunPass(ng);
- // Return the ownership of a graph back
- g->reset(ng->release());
- };
-
- if (kMklLayoutRewritePassGroup !=
- OptimizationPassRegistry::POST_PARTITIONING) {
- // For any pre-partitioning phase, a graph is stored in options.graph.
- process_graph(options.graph);
- } else {
- // For post partitioning phase, graphs are stored in
- // options.partition_graphs.
- for (auto& pg : *options.partition_graphs) {
- process_graph(&pg.second);
- }
- }
-
- return Status::OK();
-}
-
-#else // INTEL_MKL_ML_ONLY
-
// This pass implements rewriting of graph to support following scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
@@ -4539,7 +2364,7 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
return Status::OK();
}
-#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 77640e287c..0eda8170f8 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -37,1869 +37,6 @@ limitations under the License.
namespace tensorflow {
-#ifdef INTEL_MKL_ML_ONLY
-
-namespace {
-
-const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
-const char kGPUDevice[] = "/job:a/replica:0/task:0/device:GPU:0";
-
-static void InitGraph(const string& s, Graph* graph,
- const string& device = kCPUDevice) {
- GraphDef graph_def;
-
- auto parser = protobuf::TextFormat::Parser();
- // parser.AllowRelaxedWhitespace(true);
- CHECK(parser.MergeFromString(s, &graph_def)) << s;
- GraphConstructorOptions opts;
- TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
-
- for (Node* node : graph->nodes()) {
- node->set_assigned_device_name(device);
- }
-}
-
-class MklLayoutPassTest : public ::testing::Test {
- public:
- MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
-
- void InitGraph(const string& s, const string& device = kCPUDevice) {
- ::tensorflow::InitGraph(s, &graph_, device);
- original_ = CanonicalGraphString(&graph_);
- }
-
- static bool IncludeNode(const Node* n) { return n->IsOp(); }
-
- static string EdgeId(const Node* n, int index) {
- if (index == 0) {
- return n->name();
- } else if (index == Graph::kControlSlot) {
- return strings::StrCat(n->name(), ":control");
- } else {
- return strings::StrCat(n->name(), ":", index);
- }
- }
-
- string CanonicalGraphString(Graph* g) {
- std::vector<string> nodes;
- std::vector<string> edges;
- for (const Node* n : g->nodes()) {
- if (IncludeNode(n)) {
- nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
- }
- }
- for (const Edge* e : g->edges()) {
- if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
- edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
- EdgeId(e->dst(), e->dst_input())));
- }
- }
- // Canonicalize
- std::sort(nodes.begin(), nodes.end());
- std::sort(edges.begin(), edges.end());
- return strings::StrCat(str_util::Join(nodes, ";"), "|",
- str_util::Join(edges, ";"));
- }
-
- string DoMklLayoutOptimizationPass() {
- string before = CanonicalGraphString(&graph_);
- LOG(ERROR) << "Before MKL layout rewrite pass: " << before;
-
-    // RunMklLayoutRewritePass takes a std::unique_ptr<Graph>*, so wrap the
-    // member graph in a (deliberately leaked) heap-allocated unique_ptr.
-    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
-    RunMklLayoutRewritePass(ug);
-
- string result = CanonicalGraphString(&graph_);
- LOG(ERROR) << "After MKL layout rewrite pass: " << result;
- return result;
- }
-
- const string& OriginalGraph() const { return original_; }
-
- Graph graph_;
- string original_;
-};
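
To help read the expected strings in the tests below: CanonicalGraphString emits the sorted "name(op)" node list, then '|', then the sorted edge list, where "A->B:1" means output 0 of A feeds input 1 of B and a ":control" suffix marks a control slot. For example, a hypothetical two-node graph in which A feeds inputs 0 and 1 of B canonicalizes to:

A(Input);B(Zeta)|A->B;A->B:1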
-
-REGISTER_OP("Input").Output("o: float").SetIsStateful();
-REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
-REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
-REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
-REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2")
- .Output("o: uint8")
- .Output("o1: uint8")
- .SetIsStateful();
-
-/////////////////////////////////////////////////////////////////////
-// Unit tests related to node merge optimization
-/////////////////////////////////////////////////////////////////////
-
-TEST_F(MklLayoutPassTest, Basic) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Zeta);D(Zeta)|"
- "A->C;A->D;B->C:1;B->D:1");
-}
-
-// Test set 1: Conv2D + AddBias
-
-// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved ordering)
-// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous ordering)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->E;"
- "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
- "N->E:4;Y->Z:1");
-}
-
-// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for interleaved)
-// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Zeta(E,Y) (for contiguous)
-// Test for correct output slots selected
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput2'}"
- "node { name: 'N' op: '_MklInput2'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M:1', 'N:1']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput2);N(_MklInput2);Y(Input);Z(Zeta)|A->E;"
- "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
- "M:1->E:3;N:1->E:4;Y->Z:1");
-}
-
-// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Zeta(E,Y);
-// This is a case of node rewrite followed by node merge.
-// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
-// with BiasAdd to produce _MklConv2DWithBias.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Zeta)|"
- "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
- "DMT/_2->E:5;E->Z;Y->Z:1");
-}
-
-// Graph contains only _MklConv2D, no AddBias.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
- "A->C;B->C:1;M->C:2;N->C:3");
-}
-
-// _MklConv2D output does not go to BiasAdd.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
-}
-
-// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Zeta).
-// Merge should not be done in such a case.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D', 'E'] }" // Conv2D has two outputs.
- // No merge should happen.
- "node { name: 'G' op: 'Zeta'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
- "G(Zeta);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
- "E->F:1;E->G:1;M->C:2;N->C:3");
-}
-
-// data_format attribute value mismatch. Merge should not be done
-// in such case.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHCW' } }"
- " input: ['C', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
- "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
-}
-
-// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
-// rewrite tests
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
-// and BackpropInput
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'Int32Input'}"
- "node { name: 'I' op: '_MklConv2DBackpropInput'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
- "node { name: 'J' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
- "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
- "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
- "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
- "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;"
- "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
-}
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
-// and BackpropInput. But nodes do not match criteria for rewrite. So
-// rewrite should not happen.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'Int32Input'}"
- "node { name: 'I' op: '_MklConv2DBackpropInput'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
- "node { name: 'J' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
- "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
- "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
- "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
- "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
-}
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
-// and BackpropInput. But nodes do not match criteria for rewrite. So
-// rewrite should not happen.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'Int32Input'}"
- "node { name: 'I' op: '_MklConv2DBackpropInput'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
- "node { name: 'J' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
- "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
- "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
- "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
- "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
-}
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);"
- "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
- "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;"
- "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;"
- "N->D:4;N->G:4;O->D:5;O->G:5");
-}
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
-// But BackpropFilter node inputs do not satisfy criteria for rewrite.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
- "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
- "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
- "O->G:5");
-}
-
-// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
-// But BackpropFilter node inputs do not satisfy criteria for rewrite.
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'Int32Input'}"
- "node { name: 'G' op: '_MklConv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
- "node { name: 'H' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Zeta);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
- "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
- "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
- "O->G:5");
-}
-
-// No _MklConv2DWithBias in context, but _MklConv2D in context.
-// No rewrite for BiasAddGrad should happen.
-// C=_MklConv2D(A,M,B,N); D=Zeta(C,A); E=BiasAddGrad(D) (for interleaved)
-// C=_MklConv2D(A,B,M,N); D=Zeta(C,A); E=BiasAddGrad(D) (for contiguous)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Zeta);E(BiasAddGrad);"
- "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
- "M->C:2;N->C:3");
-}
-
-// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
-// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Polygamma'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No Conv2D in the context for BiasAddGrad, but MatMul in context.
-// Rewrite should happen, but name of BiasAddGrad does not change.
-// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'MatMul'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'transpose_a' value { b: false } }"
- " attr { key: 'transpose_b' value { b: false } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
-// C=MatMul(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'MatMul'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'transpose_a' value { b: false } }"
- " attr { key: 'transpose_b' value { b: false } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(MatMul);D(Zeta);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No MatMul in the context for BiasAddGrad. No rewrite should happen.
-// C=Polygamma(A,B); D=Zeta(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Polygamma'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'A']}"
- "node { name: 'E' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Polygamma);D(Zeta);E(BiasAddGrad)|"
- "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-/////////////////////////////////////////////////////////////////////
-// Unit tests related to rewriting node to Mkl node
-/////////////////////////////////////////////////////////////////////
-
-// Single Conv2D op; no Mkl layer on the input or on the output.
-// We will generate a dummy Mkl tensor as the 2nd input of Conv2D.
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['B', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Zeta);DMT/_0(Const);"
- "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
- "DMT/_1->C:3");
-}
-
-// 2 Conv2D ops in sequence. Both should get transformed, and the 1st Conv2D
-// will have 2 outputs, both of which will be inputs to the next Conv2D.
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'C']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->C;A->D;"
- "A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
- "C:2->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
-}
-
-// Conv2D with DT_HALF, which is not supported by Mkl
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
- InitGraph(
- "node { name: 'A' op: 'HalfInput'}"
- "node { name: 'B' op: 'HalfInput'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_HALF } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_HALF } }"
- " input: ['B', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(HalfInput);B(HalfInput);C(Conv2D);D(Zeta)|"
- "A->C;B->C:1;B->D;C->D:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Int32Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Conv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
- "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:4;DMT/_2->D:5");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Int32Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Conv2DBackpropInput'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['B', 'A', 'C']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
- "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|"
- "A->D:1;A->E;B->D;B:control->DMT/_0:control;"
- "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
- "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
-}
-
-// Concat Op test: Concat with no Mkl layer feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
- InitGraph(
- "node { name: 'A' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'B' op: 'InputList'"
- " attr { key: 'N' value { i: 2 } }}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Concat'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['A', 'B:0', 'B:1']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D'] }");
- EXPECT_EQ(
- DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
- "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
-}
-
-// Concat with 2 Mkl layers feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'F' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['C', 'D']}"
- "node { name: 'G' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'H' op: 'Concat'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['G', 'E', 'F']}"
- "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'H'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
- "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
- "G:control->DMT/_4:control;H->I:1");
-}
-
-// Concat with 1 Mkl and 1 non-Mkl layer feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D']}"
- "node { name: 'G' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'H' op: 'Concat'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['G', 'E', 'F']}"
- "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'H'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
- "H(_MklConcat);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
- "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:2->H:4;F->H:2;"
- "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
-}
-
-// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
- InitGraph(
- "node { name: 'A' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'B' op: 'InputList'"
- " attr { key: 'N' value { i: 2 } }}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'ConcatV2'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'Tidx' value { type: DT_INT32 } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['B:0', 'B:1', 'A']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D:2;B->D;B:1->D:1;"
- "B:control->DMT/_0:control;B:control->DMT/_1:control;"
- "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:4;DMT/_2->D:5");
-}
-
-// ConcatV2 with 2 Mkl layers feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'F' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['C', 'D']}"
- "node { name: 'G' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'H' op: 'ConcatV2'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'Tidx' value { type: DT_INT32 } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['E', 'F', 'G']}"
- "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'H'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
- "C:control->DMT/_0:control;C:control->DMT/_1:control;"
- "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
- "F:2->H:4;G->H:2;H->I:1");
-}
-
-// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D']}"
- "node { name: 'G' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'H' op: 'ConcatV2'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'Tidx' value { type: DT_INT32 } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['E', 'F', 'G']}"
- "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'H'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Zeta);G(Const);"
- "H(_MklConcatV2);I(Zeta)|A->E;A->I;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
- "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:2->H:3;"
- "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
- "G->H:2;H->I:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklRelu);C(Zeta);DMT/_0(Const)|A->B;A->C;"
- "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'ReluGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
- "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'ReluGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklRelu);C(_MklReluGrad);D(Zeta);DMT/_0(Const);"
- "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
- "DMT/_1->C:2");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'AvgPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklAvgPool);C(Zeta);DMT/_0(Const)|A->B;A->C;"
- "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Int32Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'AvgPoolGrad' "
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['B', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
- "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
- "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
- "DMT/_1->C:3");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'I' op: 'Int32Input'}"
- "node { name: 'B' op: 'AvgPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'AvgPoolGrad' "
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['I', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'C'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Zeta);DMT/_0(Const);"
- "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
- "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
- "I:control->DMT/_1:control");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'FusedBatchNormGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'epsilon' value { f: 0.0001 } }"
- " attr { key: 'is_training' value { b: true } }"
- " input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'F'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
- "F(_MklFusedBatchNormGrad);G(Zeta)|A->F;A->G;"
- "A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
- "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
- "E->F:4;F->G:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'FusedBatchNorm'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'epsilon' value { f: 0.0001 } }"
- " attr { key: 'is_training' value { b: true } }"
- " input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'F'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
- "F(_MklFusedBatchNorm);G(Zeta)|A->F;A->G;"
- "A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
- "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
- "E->F:4;F->G:1");
-}
-
-/////////////////////////////////////////////////////////////////////
-// Unit tests related to rewriting node for workspace edges
-/////////////////////////////////////////////////////////////////////
-
-/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
-TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'LRN'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['B'] }"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'MaxPoolGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['B', 'C', 'D'] }"
- "node { name: 'F' op: 'Input'}"
- "node { name: 'G' op: 'LRNGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['E', 'F', 'B'] }"
- "node { name: 'H' op: 'Input'}"
- "node { name: 'I' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['H', 'G'] }");
- EXPECT_EQ(
- DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
- "I(Zeta)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
- "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
- "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
- "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
-}
-
-/* Test LRN->LRNGrad replacement by workspace nodes. */
-TEST_F(MklLayoutPassTest, LRN_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'LRN'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'LRNGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['C', 'D', 'B'] }"
- "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklLRNGrad);F(Zeta)|"
- "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
- "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
- "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
-}
-
-/* Test LRN->LRNGrad replacement when only one of them is present. */
-TEST_F(MklLayoutPassTest, LRN_Negative1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'LRN'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklLRN);C(Zeta);DMT/_0(Const)|"
- "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
-}
-
-/* Test LRN->LRNGrad replacement when only one of them is present. */
-TEST_F(MklLayoutPassTest, LRN_Negative2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'LRNGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['A', 'B', 'C'] }"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
- "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
-}
-
-/* Test LRN->LRNGrad negative case, where a single LRN feeds
-   2 LRNGrad nodes at different slots. */
-TEST_F(MklLayoutPassTest, LRN_Negative3) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'LRN'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'LRNGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['C', 'D', 'B'] }"
- "node { name: 'F' op: 'LRNGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'alpha' value { f: 0.001 } }"
- " attr { key: 'beta' value { f: 0.75 } }"
- " attr { key: 'bias' value { f: 1.0 } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'depth_radius' value { i: 2 } }"
- " input: ['C', 'B', 'D'] }"
- "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'F'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
- "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
- "A:control->DMT/_0:control;B->E:2;"
- "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
- "C:control->DMT/_1:control;C:control->DMT/_2:control;"
- "C:control->DMT/_3:control;C:control->DMT/_4:control;"
- "C:control->DMT/_5:control;C:control->DMT/_6:control;"
- "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
- "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
-}
-
-/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'MaxPoolGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['C', 'B', 'D'] }"
- "node { name: 'F' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'E'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Zeta)|"
- "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
- "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
- "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
-}
-
-// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
-// In this case, we will rewrite the MaxPool node, but workspace edges will
-// not be present.
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(_MklMaxPool);C(Zeta);DMT/_0(Const)|"
- "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
-}
-
-// Test MaxPoolGrad replacement when only one of them is present.
-// In this case, we will rewrite MaxPoolGrad and generate dummy tensors for
-// the workspace tensor and its Mkl part.
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'MaxPoolGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
- " input: ['A', 'B', 'C'] }"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'D'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
- "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
- "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
-}
-
-// Test MaxPool handling for batch-wise pooling (NCHW)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for batch-wise pooling (NCHW)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for depth-wise pooling (NCHW)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for depth-wise pooling (NCHW)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for batch-wise pooling (NHWC)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for batch-wise pooling (NHWC)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for depth-wise pooling (NHWC)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Test MaxPool handling for depth-wise pooling (NHWC)
-// No rewrite should take place in such case
-TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }");
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-/////////////////////////////////////////////////////////////////////
-
-// Single Conv2D Op on GPU device
-// No rewrite should happen
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Conv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B']}"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['B', 'C'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Conv2D);D(Zeta)|A->C;B->C:1;B->D;C->D:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'O' op: '_MklInput'}"
- "node { name: 'D' op: '_MklConv2DWithBias'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
- "node { name: 'E' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['D', 'A']}"
- "node { name: 'F' op: 'BiasAddGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['E'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
- "E(Zeta);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
- "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
- "M->D:3;N->D:4;O->D:5");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Int32Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Conv2DBackpropFilter'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'C']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'D'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Zeta)|"
- "A->D;A->E;B->D:1;C->D:2;D->E:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Relu'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Relu);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'ReluGrad'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }"
- "node { name: 'D' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'C'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(ReluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'MaxPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(MaxPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'AvgPool'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NHWC' } }"
- " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'VALID' } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " input: ['A'] }"
- "node { name: 'C' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'B'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(AvgPool);C(Zeta)|A->B;A->C;B->C:1");
-}
-
-// Concat Op test: Concat with no Mkl layer feeding it
-TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'B' op: 'InputList'"
- " attr { key: 'N' value { i: 2 } }}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Concat'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['A', 'B:0', 'B:1']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(Concat);E(Zeta)|A->D;"
- "B->D:1;B:1->D:2;C->E;D->E:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Const' "
- " attr { key: 'dtype' value { type: DT_INT32 } }"
- " attr { key: 'value' value { "
- " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
- " int_val: 0 } } } }"
- "node { name: 'B' op: 'InputList'"
- " attr { key: 'N' value { i: 2 } }}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'ConcatV2'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'Tidx' value { type: DT_INT32 } }"
- " attr { key: 'N' value { i: 2 } }"
- " input: ['B:0', 'B:1', 'A']}"
- "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['C', 'D'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|"
- "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'C' op: 'Input'}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'Input'}"
- "node { name: 'F' op: 'FusedBatchNorm'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'epsilon' value { f: 0.0001 } }"
- " attr { key: 'is_training' value { b: true } }"
- " input: ['A', 'B', 'C', 'D', 'E'] }"
- "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
- " input: ['A', 'F'] }",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(Input);D(Input);E(Input);"
- "F(FusedBatchNorm);G(Zeta)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
- "E->F:4;F->G:1");
-}
-
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
- CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
- InitGraph(
- "node { name: 'A' op: 'Input'}"
- "node { name: 'B' op: 'Input'}"
- "node { name: 'M' op: '_MklInput'}"
- "node { name: 'N' op: '_MklInput'}"
- "node { name: 'C' op: '_MklConv2D'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
- " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
- " attr { key: 'padding' value { s: 'SAME' } }"
- " input: ['A', 'B', 'M', 'N']}"
- "node { name: 'D' op: 'Input'}"
- "node { name: 'E' op: 'BiasAdd'"
- " attr { key: 'T' value { type: DT_FLOAT } }"
- " attr { key: 'data_format' value { s: 'NCHW' } }"
- " input: ['C', 'D'] }"
- "node { name: 'Y' op: 'Input'}"
- "node { name: 'Z' op: 'Zeta'"
- " attr {key: 'T' value { type: DT_FLOAT } }"
- " input: ['E', 'Y']}",
- kGPUDevice);
- EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
- "M(_MklInput);N(_MklInput);Y(Input);Z(Zeta)|A->C;"
- "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
-}
-
-/////////////////////////////////////////////////////////////////////
-
-static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
- testing::StopTiming();
- string s;
- for (int in = 0; in < 10; in++) {
- s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
- }
- random::PhiloxRandom philox(301, 17);
- random::SimplePhilox rnd(&philox);
- for (int op = 0; op < op_nodes; op++) {
- s += strings::Printf(
- "node { name: 'op%04d' op: 'Zeta' attr { key: 'T' value { "
- "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
- op, rnd.Uniform(10), rnd.Uniform(10));
- }
-
- bool first = true;
- while (iters > 0) {
- Graph* graph = new Graph(OpRegistry::Global());
- InitGraph(s, graph);
- int N = graph->num_node_ids();
- if (first) {
- testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
- first = false;
- }
- {
- testing::StartTiming();
- std::unique_ptr<Graph> ug(graph);
- RunMklLayoutRewritePass(&ug);
- testing::StopTiming();
- }
- iters -= N; // Our benchmark units are individual graph nodes,
- // not whole graphs
- // delete graph;
- }
-}
-BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
-
-} // namespace
-
-#else // INTEL_MKL_ML_ONLY
-
// NOTE: Unit tests in this file rely on a topological sorted graph for
// printing. But since sibling nodes of a node in the topologically sorted graph
// can be printed in different orders, tests may fail if the order in which
@@ -3602,8 +1739,6 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace
-#endif // INTEL_MKL_ML_ONLY
-
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL