path: root/tensorflow/core/graph/mkl_layout_pass.cc
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc  595
1 file changed, 385 insertions(+), 210 deletions(-)
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 09b632a165..94741a11ff 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/util/mkl_util.h"
@@ -272,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+ csinfo_.identity = "Identity";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
@@ -280,51 +282,75 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
- "_MklConv2DWithBiasBackpropBias";
- csinfo_.relu = "Relu";
- csinfo_.reshape = "Reshape";
- csinfo_.relu_grad = "ReluGrad";
- csinfo_.split = "Split";
+ "_MklConv2DWithBiasBackpropBias";
+ csinfo_.relu = "Relu";
+ csinfo_.relu_grad = "ReluGrad";
+ csinfo_.reshape = "Reshape";
+ csinfo_.split = "Split";
// NOTE: names are alphabetically sorted.
- rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
- CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool,
+ GetMklOpName(csinfo_.avg_pool),
+ CopyAttrsPooling, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.avg_pool_grad,
- GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
- AlwaysRewrite});
- rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
- CopyAttrsConcat, AlwaysRewrite});
- rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
- CopyAttrsConcatV2, AlwaysRewrite});
- rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
- CopyAttrsConv2D, AlwaysRewrite});
+ GetMklOpName(csinfo_.avg_pool_grad),
+ CopyAttrsPooling, AlwaysRewrite, nullptr});
+ // BiasAddGrad is rewritten into _MklConv2DWithBiasBackpropBias when
+ // its context contains a Conv2D node.
+ rinfo_.push_back({csinfo_.bias_add_grad,
+ csinfo_.mkl_conv2d_with_bias_backprop_bias,
+ CopyAttrsBiasAddGrad, ContextMatchRewrite,
+ &biasaddgrad_conv2dwithbias_context_});
+ // BiasAddGrad is left as BiasAddGrad (no rewrite) when its context
+ // contains a MatMul node.
+ rinfo_.push_back({csinfo_.bias_add_grad, csinfo_.matmul,
+ CopyAttrsBiasAddGrad, ContextMatchRewrite,
+ &biasaddgrad_matmul_context_});
+ rinfo_.push_back({csinfo_.concat,
+ GetMklOpName(csinfo_.concat),
+ CopyAttrsConcat, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.concatv2,
+ GetMklOpName(csinfo_.concatv2),
+ CopyAttrsConcatV2, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.conv2d,
+ GetMklOpName(csinfo_.conv2d),
+ CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
- GetMklOpName(csinfo_.conv2d_grad_filter), 3,
- CopyAttrsConv2D, AlwaysRewrite});
+ GetMklOpName(csinfo_.conv2d_grad_filter),
+ CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.conv2d_grad_input,
- GetMklOpName(csinfo_.conv2d_grad_input), 3,
- CopyAttrsConv2D, AlwaysRewrite});
+ GetMklOpName(csinfo_.conv2d_grad_input),
+ CopyAttrsConv2D, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm,
- GetMklOpName(csinfo_.fused_batch_norm), 5,
- CopyAttrsFusedBatchNorm, AlwaysRewrite});
+ GetMklOpName(csinfo_.fused_batch_norm),
+ CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
- GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
- CopyAttrsFusedBatchNorm, AlwaysRewrite});
- rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
- AlwaysRewrite});
- rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
- CopyAttrsLRN, AlwaysRewrite});
- rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
- CopyAttrsPooling, AlwaysRewrite});
+ GetMklOpName(csinfo_.fused_batch_norm_grad),
+ CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.identity,
+ GetMklOpName(csinfo_.identity),
+ CopyAttrsIdentity, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.lrn,
+ GetMklOpName(csinfo_.lrn),
+ CopyAttrsLRN, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.lrn_grad,
+ GetMklOpName(csinfo_.lrn_grad),
+ CopyAttrsLRN, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.max_pool,
+ GetMklOpName(csinfo_.max_pool),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite, nullptr});
rinfo_.push_back({csinfo_.max_pool_grad,
- GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
- AlwaysRewrite});
- rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
- CopyAttrsRelu, AlwaysRewrite});
- rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
- CopyAttrsReshape, AlwaysRewrite});
-
- // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
+ GetMklOpName(csinfo_.max_pool_grad),
+ CopyAttrsPooling, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.relu,
+ GetMklOpName(csinfo_.relu),
+ CopyAttrsRelu, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.relu_grad,
+ GetMklOpName(csinfo_.relu_grad),
+ CopyAttrsRelu, AlwaysRewrite, nullptr});
+ rinfo_.push_back({csinfo_.reshape,
+ GetMklOpName(csinfo_.reshape),
+ CopyAttrsReshape, AlwaysRewrite, nullptr});
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -338,8 +364,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// maxhops in the backward data-flow graph. Since the input of forward
// nodes (Conv2D) goes directly to backward nodes, we do not expect the
// hop-distance to be more than a few nodes.
- cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
- kNodeMergeContextMaxDepth});
+ biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
+ kNodeMergeContextMaxDepth};
+
+ biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
+ csinfo_.mkl_conv2d_with_bias,
+ kNodeMergeContextMaxDepth};
+
+ cinfo_.push_back(&biasaddgrad_matmul_context_);
+ cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
}
// Standard interface to run pass
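The hop-bounded context matching configured above (kNodeMergeContextMaxDepth) can be pictured with a small self-contained sketch. Everything below - the DemoNode type, the function name, the breadth-first traversal - is invented for illustration and is not the pass's API; the real implementation walks the TensorFlow data-dependence graph depth-first.

#include <cstddef>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a graph node: an op type plus its input nodes.
struct DemoNode {
  std::string type;
  std::vector<const DemoNode*> inputs;
};

// Returns true if an op of type 'fwd' is reachable from 'n' through input
// edges within 'max_hop' hops - the same bound the pass applies when asking
// whether a BiasAddGrad sits in a _MklConv2DWithBias (or MatMul) context.
bool WithinContext(const DemoNode* n, const std::string& fwd, size_t max_hop) {
  std::queue<std::pair<const DemoNode*, size_t>> frontier;
  frontier.push({n, 0});
  while (!frontier.empty()) {
    auto [cur, dist] = frontier.front();
    frontier.pop();
    if (cur->type == fwd) return true;
    if (dist == max_hop) continue;  // hop budget exhausted on this path
    for (const DemoNode* in : cur->inputs) frontier.push({in, dist + 1});
  }
  return false;
}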
@@ -354,7 +387,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @return true, if and only if graph is mutated; false otherwise.
bool RunPass(std::unique_ptr<Graph>* g);
- private:
+ /// Structure to specify the context information used in a node rewrite rule
+ typedef struct {
+ string node; // Name of the node to be rewritten
+ string fwd; // Name of the node in the forward pass that this node
+ // corresponds to
+ size_t max_hop; // Maximum number of hops at which the fwd node may
+ // be located from this node. If the fwd node is
+ // farther than max_hop, we do not rewrite this node.
+ } ContextInfo;
+
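Each ContextInfo is consumed by the rewrite-rule predicate stored in the RewriteInfo table defined next. A minimal self-contained sketch of that dispatch pattern follows; all names are illustrative stand-ins, not the pass's actual types.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct DemoContextInfo {
  std::string node;  // op to be rewritten
  std::string fwd;   // forward-pass op that must appear in its context
};

struct DemoRewriteInfo {
  std::string name;      // original op name
  std::string new_name;  // op name after rewrite
  std::function<bool(const std::string&, const DemoContextInfo*)> rewrite_rule;
  const DemoContextInfo* context;  // nullptr when the rule needs no context
};

int main() {
  DemoContextInfo conv_ctx{"BiasAddGrad", "_MklConv2DWithBias"};
  std::vector<DemoRewriteInfo> rules = {
      // Context-free rule: rewrite unconditionally.
      {"Relu", "_MklRelu",
       [](const std::string&, const DemoContextInfo*) { return true; },
       nullptr},
      // Context-sensitive rule: the predicate consults the context it was
      // registered with (the real pass also searches the graph).
      {"BiasAddGrad", "_MklConv2DWithBiasBackpropBias",
       [](const std::string& op, const DemoContextInfo* c) {
         return c != nullptr && op == c->node;
       },
       &conv_ctx},
  };
  for (const auto& r : rules)
    if (r.rewrite_rule("BiasAddGrad", r.context))
      std::cout << r.name << " -> " << r.new_name << "\n";
}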
/// Structure to specify the name of an original node, its new name after
/// rewrite, the function to be used to copy attributes for the op,
/// and the rule (if any) which
@@ -362,11 +404,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
typedef struct {
string name; // Original name of op of the node in the graph
string new_name; // New name of the op of the node in the graph
- int num_ins; // The number of inputs to the original op type
// A function handler to copy attributes from an old node to a new node.
std::function<void(const Node*, NodeBuilder*)> copy_attrs;
- std::function<bool(const Node*)> rewrite_rule; // A rule under which to
- // rewrite this node.
+ // A rule under which to rewrite this node
+ std::function<bool(const Node*, const ContextInfo* c)> rewrite_rule;
+ // ContextInfo, if any, to be used for rewrite
+ ContextInfo* context;
} RewriteInfo;
/// Structure to specify a forward op, a backward op, and the slot numbers
@@ -393,16 +436,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string new_node; // Name of the node after merge
} MergeInfo;
- /// Structure to specify the context information used in a node rewrite rule
- typedef struct {
- string node; // Name of the node to be rewritten
- string fwd; // Name of the node in the forward pass that this node
- // corresponds to
- size_t max_hop; // Maximum number of hops the fwd is located
- // from this node. If the fwd is farther than max_hop
- // then we do not rewrite the node.
- } ContextInfo;
-
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
struct {
@@ -417,6 +450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv2d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
+ string identity;
string lrn;
string lrn_grad;
string matmul;
@@ -427,10 +461,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_conv2d_with_bias_backprop_bias;
string relu;
string relu_grad;
- string split;
string reshape;
+ string split;
} csinfo_;
+ private:
/// Maintain info about nodes to rewrite
std::vector<RewriteInfo> rinfo_;
@@ -441,7 +476,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<MergeInfo> minfo_;
/// Maintain info about nodes to rewrite
- static std::vector<ContextInfo> cinfo_;
+ static std::vector<ContextInfo*> cinfo_;
+
+ /// Context variables used in referencing rules
+ static ContextInfo biasaddgrad_matmul_context_;
+ static ContextInfo biasaddgrad_conv2dwithbias_context_;
/// Hash table to maintain nodes visited in the graph.
std::unordered_set<const Node*> visited_nodes_;
@@ -464,19 +503,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Clear all visited nodes
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
- // Is this a graph node that can accept variable number of inputs?
- // Return true if yes, false otherwise.
- //
- // Concat, Split are vararg nodes.
- inline bool IsVarArgNode(Node* n) {
- if (n->type_string() == csinfo_.concat ||
- n->type_string() == csinfo_.concatv2 ||
- n->type_string() == csinfo_.split) {
- return true;
- }
- return false;
- }
-
// Is OpDef::ArgDef a list type? It could be N * T or list(type).
// Refer to opdef.proto for details of list type.
inline bool ArgIsList(const OpDef::ArgDef& arg) const {
@@ -510,6 +536,39 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return string(kMklOpPrefix) + name;
}
+ // Can op represented by node 'n' run on DEVICE_CPU?
+ // Op can run on CPU with MKL if the runtime assigned device or the
+ // user requested device contains device CPU, or both are empty.
+ bool CanOpRunOnCPUDevice(const Node* n) {
+ bool result = true;
+ string reason;
+
+ // Substring to look for in a device name to identify the CPU device.
+ const char* const kCPUDeviceSubStr = "cpu";
+
+ // If Op has been specifically assigned to a non-CPU device, then No.
+ if (!n->assigned_device_name().empty() &&
+ !StringPiece(n->assigned_device_name()).contains(kCPUDeviceSubStr)) {
+ result = false;
+ reason = "Op has been assigned a runtime device that is not CPU.";
+ }
+
+ // If user has specifically assigned this op to a non-CPU device, then No.
+ if (!n->def().device().empty() &&
+ !StringPiece(n->def().device()).contains(kCPUDeviceSubStr)) {
+ result = false;
+ reason = "User has assigned a device that is not CPU.";
+ }
+
+ if (result == false) {
+ VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
+ << n->type_string() << ", reason: " << reason;
+ }
+
+ // Otherwise Yes.
+ return result;
+ }
+
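Stripped of TensorFlow types, the check above is a pair of substring tests over possibly-empty device strings. A standalone approximation using std::string in place of StringPiece (the device-name formats in the examples are illustrative):

#include <iostream>
#include <string>

// An op may run on CPU unless either the runtime-assigned or the
// user-requested device names a non-CPU device; empty means unconstrained.
bool CanRunOnCPU(const std::string& assigned, const std::string& requested) {
  const std::string kCPUDeviceSubStr = "cpu";
  if (!assigned.empty() && assigned.find(kCPUDeviceSubStr) == std::string::npos)
    return false;  // runtime placed the op elsewhere
  if (!requested.empty() &&
      requested.find(kCPUDeviceSubStr) == std::string::npos)
    return false;  // user explicitly asked for another device
  return true;
}

int main() {
  std::cout << CanRunOnCPU("", "") << "\n";                     // 1
  std::cout << CanRunOnCPU("/job:w/task:0/cpu:0", "") << "\n";  // 1
  std::cout << CanRunOnCPU("", "/device:gpu:0") << "\n";        // 0
}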
// Return a node that can be merged with input node 'n'
//
// @return pointer to the node if we can find such a
@@ -538,13 +597,46 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Default rewrite rule to be used in scenario 1 for rewrite.
// @return - true (since we want to always rewrite)
- static bool AlwaysRewrite(const Node* n) { return true; }
- // Rewrite rule that uses context-information for matching
+ static bool AlwaysRewrite(const Node* n, const ContextInfo* c = nullptr) {
+ return true;
+ }
+
+ // Check whether we are performing pooling on the depth or batch
+ // dimension. If so, we do not rewrite the MaxPool node to its Mkl
+ // version.
+ // @return - true if it is not a depth- or batch-wise pooling case;
+ // false otherwise.
+ static bool NonDepthBatchWisePoolRewrite(const Node* n,
+ const ContextInfo* c) {
+ CHECK_NOTNULL(n);
+
+ string data_format_str;
+ TensorFormat data_format;
+ std::vector<int32> ksize, strides;
+ CHECK_EQ(GetNodeAttr(n->def(), "ksize", &ksize).ok(), true);
+ CHECK_EQ(GetNodeAttr(n->def(), "strides", &strides).ok(), true);
+ CHECK_EQ(GetNodeAttr(n->def(), "data_format", &data_format_str).ok(),
+ true);
+ CHECK_EQ(FormatFromString(data_format_str, &data_format), true);
+
+ // Condition that specifies non-batch-wise and non-depth-wise pooling.
+ if (GetTensorDim(ksize, data_format, 'N') == 1 &&
+ GetTensorDim(strides, data_format, 'N') == 1 &&
+ GetTensorDim(ksize, data_format, 'C') == 1 &&
+ GetTensorDim(strides, data_format, 'C') == 1) {
+ return true;
+ }
+
+ return false;
+ }
+
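For a concrete instance of the predicate above: with NHWC data, GetTensorDim(v, data_format, 'N') is v[0] and GetTensorDim(v, data_format, 'C') is v[3], so the rule collapses to four vector lookups. A self-contained sketch that hard-codes NHWC instead of parsing data_format:

#include <vector>

// True when neither the pooling window nor the stride touches the batch (N)
// or depth (C) dimension; only such purely spatial pooling is rewritten to
// the Mkl version. Assumes NHWC layout: index 0 is N, index 3 is C.
bool IsSpatialPoolingNHWC(const std::vector<int>& ksize,
                          const std::vector<int>& strides) {
  return ksize[0] == 1 && strides[0] == 1 &&  // batch dimension untouched
         ksize[3] == 1 && strides[3] == 1;    // depth dimension untouched
}

// Example: ksize {1, 2, 2, 1} with strides {1, 2, 2, 1} -> true (2x2 spatial
// pooling); ksize {2, 1, 1, 1} -> false (pools across the batch), no rewrite.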
+ // Rewrite rule that uses context-information for matching,
// used in scenario 2.
//
// @input - Node 'n' for which to search for matching context
- // @return - true if matching context is found; false otherwise.
- static bool ContextMatchRewrite(const Node* n);
+ // @input - The context 'c' under which to rewrite
+ // @return - true if we can rewrite node under context 'c';
+ // false otherwise.
+ static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
// Helper function that searches for the matching ContextInfo for the node.
// Implements depth-first search in the data dependence graph for the
@@ -598,6 +690,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// node that we are constructing.
//
// @input g - input graph,
+ // @input orig_node - Original node that we are rewriting
// @input inputs - inputs to old node that we are using for constructing
// new inputs,
// @input input_idx - the index in the 'inputs' vector pointing to the
@@ -608,11 +701,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
- void GetNodesProducingMklTensorList(
- std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
- int* input_idx, int list_length,
- std::vector<NodeBuilder::NodeOut>* output_nodes);
+ void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
+ Node* orig_node, const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@@ -620,6 +712,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// if 'n' is not an Mkl layer.
//
// @input g - input graph,
+ // @input orig_node - Original node that we are rewriting,
// @input n - Node based on which we are creating Mkl node,
// @input n_output_slot - the output slot of node 'n'
// which is feeding to the node that we are constructing
@@ -627,9 +720,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output mkl_node_output_slot - the slot number of mkl_node that
// will feed the tensor
// @return None
- void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
- int n_output_slot, Node** mkl_node,
- int* mkl_node_output_slot);
+ void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* orig_node,
+ Node* n, int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@@ -680,6 +772,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
@@ -695,13 +788,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
Node* orig_node);
};
-std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
+MklLayoutRewritePass::ContextInfo
+ MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
+MklLayoutRewritePass::ContextInfo
+ MklLayoutRewritePass::biasaddgrad_matmul_context_;
+std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
-// We register Mkl rewrite pass for phase 1 in post rewrite group.
+// We register the Mkl rewrite pass for phase 1 in the post-partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
- MklLayoutRewritePass);
+const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
+ OptimizationPassRegistry::POST_PARTITIONING;
+REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -737,27 +835,14 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
while (list_length != 0) {
CHECK_GT(list_length, 0);
- CHECK_LE(*input_idx, inputs.size());
+ CHECK_LT(*input_idx, inputs.size());
Node* n = inputs[*input_idx].first;
int slot = inputs[*input_idx].second;
- const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
- // If input node 'n' is producing a list/array output at output
- // slot 'slot' then we need to find out the length of that list/array.
- if (ArgIsList(arg)) {
- int N = GetTensorListLength(arg, n);
- CHECK_LE(N, list_length);
- for (int j = 0; j < N; j++) {
- output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
- }
- (*input_idx)++;
- list_length -= N;
- } else {
- // But if input node 'n' is just producing a single tensor at
- // output slot 'slot' then we just add that single node.
- output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
- (*input_idx)++;
- list_length--;
- }
+ // If input node 'n' is just producing a single tensor at
+ // output slot 'slot' then we just add that single node.
+ output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+ (*input_idx)++;
+ list_length--;
}
}
@@ -775,20 +860,39 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // the same device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
+
+ // If the number of inputs to the original node is > 0, we add a control
+ // dependency between the 1st input (index 0) of the original node and the
+ // dummy Mkl node. This is needed because control-flow ops such as Enter and
+ // Merge require the frame_name of the dummy Mkl node to be the same as that
+ // of the rewritten node. Adding a control edge between the 1st input of the
+ // original node and the dummy Mkl node ensures that the dummy node is in
+ // the same frame as the original node. Choosing the 1st input is arbitrary -
+ // any input of the original node works, because all inputs of a node are
+ // always in the same frame.
+ if (orig_node->num_inputs() > 0) {
+ Node* orig_input0 = nullptr;
+ TF_CHECK_OK(orig_node->input_node(0,
+ const_cast<const Node**>(&orig_input0)));
+ CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+ }
+
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
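The control edge added in this function carries no data; it only pins the dummy constant to the same execution frame as the original node. A toy sketch of that wiring with invented types (this is not the tensorflow::Graph API):

#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<ToyNode*> control_inputs;  // control edges carry no tensors
};

// Make 'dummy' start no earlier than 'anchor'. Because all inputs of a node
// live in one control-flow frame, anchoring the dummy constant to any input
// of the original node places it in the original node's frame.
void AddControlEdge(ToyNode* anchor, ToyNode* dummy) {
  dummy->control_inputs.push_back(anchor);
}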
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
- const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
- int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+ Node* orig_node,
+ const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+ int* input_idx, int list_length,
+ std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@@ -796,38 +900,19 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
while (list_length != 0) {
CHECK_GT(list_length, 0);
- CHECK_LE(*input_idx, inputs.size());
+ CHECK_LT(*input_idx, inputs.size());
Node* n = inputs[*input_idx].first;
int slot = inputs[*input_idx].second;
- const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
- // We need to check first if the input edge is going to carry a
- // single tensor or a list of tensors. If it is a list of tensors,
- // then we need to create list of Mkl dummy nodes.
- if (ArgIsList(arg)) {
- // If input node 'n' is producing a list/array output at output
- // slot 'slot' then we need to find out the length of that list/array.
- int N = GetTensorListLength(arg, n);
- CHECK_LE(N, list_length);
- Node* mkl_node = nullptr;
- int mkl_node_output_slot = 0;
- // If it is a list, then create a list of Mkl dummy nodes.
- for (int j = 0; j < N; j++) {
- GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
- output_nodes->push_back(
- NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
- }
- (*input_idx)++;
- list_length -= N;
- } else {
- // If it is not a list, then create a single Mkl tensor node.
- Node* mkl_node = nullptr;
- int mkl_node_output_slot = 0;
- GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
- output_nodes->push_back(
- NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
- (*input_idx)++;
- list_length--;
- }
+ // If 'n' is producing a single tensor, then create a single Mkl tensor
+ // node.
+ Node* mkl_node = nullptr;
+ int mkl_node_output_slot = 0;
+ GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
+ &mkl_node_output_slot);
+ output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
+ mkl_node_output_slot));
+ (*input_idx)++;
+ list_length--;
}
}
@@ -835,9 +920,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
-void MklLayoutRewritePass::GetNodeProducingMklTensor(
- std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
- int* mkl_node_output_slot) {
+void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
+ Node* orig_node, Node* n,
+ int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@@ -860,7 +945,7 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(
// to create a dummy node that will feed a dummy Mkl tensor to this node.
// DummyMklTensor node has no input and generates only 1 output
// (dummy Mkl tensor) as output slot number 0.
- GetDummyMklTensorNode(g, mkl_node, n);
+ GetDummyMklTensorNode(g, mkl_node, orig_node);
CHECK_NOTNULL(*mkl_node);
*mkl_node_output_slot = 0;
}
@@ -926,16 +1011,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
- GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
- &new_node_inputs);
+ GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
+ N, &new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
- GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
- old_node_inputs[iidx].second, &mkl_node,
- &mkl_node_output_slot);
+ GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
+ old_node_inputs[iidx].second,
+ &mkl_node, &mkl_node_output_slot);
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;
@@ -1020,13 +1105,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
- .Attr("value", proto)
- .Attr("dtype", dt)
- .Device(orig_node->def().device()) // We place this node on
- // same the device as the
- // device of the original
- // node.
- .Finalize(&**g, out));
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Device(orig_node->def().device()) // We place this node on
+ // the same device as the
+ // device of the original
+ // node.
+ .Finalize(&**g, out));
+
+ // If the number of inputs to the original node is > 0, we add a control
+ // dependency between the 1st input (index 0) of the original node and the
+ // dummy Mkl node. This is needed because control-flow ops such as Enter and
+ // Merge require the frame_name of the dummy Mkl node to be the same as that
+ // of the rewritten node. Adding a control edge between the 1st input of the
+ // original node and the dummy Mkl node ensures that the dummy node is in
+ // the same frame as the original node. Choosing the 1st input is arbitrary -
+ // any input of the original node works, because all inputs of a node are
+ // always in the same frame.
+ if (orig_node->num_inputs() > 0) {
+ Node* orig_input0 = nullptr;
+ TF_CHECK_OK(orig_node->input_node(0,
+ const_cast<const Node**>(&orig_input0)));
+ CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+ }
+
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
@@ -1179,6 +1281,16 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
nb->Attr("data_format", data_format);
}
+void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+}
+
void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -1235,6 +1347,19 @@ void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
nb->Attr("T", T);
}
+void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Tshape;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Tshape", Tshape);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -1303,20 +1428,6 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
nb->Attr("is_training", is_training);
}
-void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
- NodeBuilder* nb) {
- DataType T;
- DataType Tshape;
-
- // Get all attributes from old node.
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
- TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
-
- // Add attributes to new node.
- nb->Attr("T", T);
- nb->Attr("Tshape", Tshape);
-}
-
//////////////////////////////////////////////////////////////////////////
// Helper functions related to node merge pass
//////////////////////////////////////////////////////////////////////////
@@ -1353,8 +1464,9 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
continue;
}
+ const int B_in = b->num_inputs();
gtl::InlinedVector<Node*, 4> b_control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
+ gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
FillInputs(b, &b_control_edges, &b_in);
// Shouldn't merge if a and b have different control edges.
@@ -1438,7 +1550,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
CHECK_EQ(succ->in_edges().size(), 2);
Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
int oper3_mkl_slot = 0; // For dummy MKL tensor node, output slot is 0.
- GetDummyMklTensorNode(g, &oper3_mkl, succ); // Get dummy Mkl tensor node
+ GetDummyMklTensorNode(g, &oper3_mkl, pred); // Get dummy Mkl tensor node
// as BiasAdd does not have Mkl tensor as input.
CHECK_NOTNULL(oper3_mkl);
@@ -1483,9 +1595,38 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// Set the Mkl layer label for this op.
new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
+ // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
+ // node are already copied in BuildNode. We handle control edges now.
+ for (const Edge* e : pred->in_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ }
+ }
+ for (const Edge* e : succ->in_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ }
+ }
+
// Incoming edges are fixed, we will fix the outgoing edges now.
+ // First, we will fix outgoing control edges from 'pred' node.
+ // We don't need to handle outgoing data edges from 'pred' node
+ // because pred has only 1 output going to succ node (we enforced
+ // this check for merge already).
+ for (const Edge* e : pred->out_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ }
+ }
+
+ // Second, we will fix outgoing control and data edges from 'succ' node.
for (const Edge* e : succ->out_edges()) {
- (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ } else {
+ CHECK_NOTNULL((*g)->AddEdge(new_node, e->src_output(), e->dst(),
+ e->dst_input()));
+ }
}
// Copy device assigned to old node to new node.
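A side note on the edge tests used above: the removed code classified control edges by the reserved slot number -1, while the new code asks the edge directly; the two views coincide, as this invented stand-in illustrates.

// Illustrative stand-in for a graph edge, showing the slot convention only.
struct ToyEdge {
  int src_output;  // output slot on the source node; -1 is the control slot
  bool IsControlEdge() const { return src_output < 0; }
};
// The rewiring above then splits per edge:
//   control edge -> AddControlEdge(new_node, e->dst())
//   data edge    -> AddEdge(new_node, e->src_output(), e->dst(), e->dst_input())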
@@ -1550,18 +1691,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
"data_format or T attribute or devices of BiasAddGrad and "
"Conv2D do not match. Will skip node rewrite optimization");
}
+ } else if (orig_node->type_string() == csinfo_.bias_add_grad &&
+ ri->new_name == csinfo_.matmul) {
+ // When BiasAddGrad has MatMul in its context, we do not rewrite it and
+ // leave BiasAddGrad as it is. That condition is already checked when the
+ // node rewrite rule is applied, so we should never reach this point for
+ // a MatMul context; fail now.
+ return Status(
+ error::Code::INVALID_ARGUMENT,
+ "No rewrite is required for BiasAddGrad for MatMul context.");
}
}
// Get all inputs.
- const int num = orig_node->in_edges().size();
- // Check the number of inputs against the user-specified value for non-vararg
- // nodes.
- if (!IsVarArgNode(orig_node)) {
- CHECK_EQ(num, ri->num_ins);
- }
+ const int num_inputs = orig_node->in_edges().size();
gtl::InlinedVector<Node*, 4> control_edges;
- gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
+ gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_inputs);
FillInputs(orig_node, &control_edges, &inputs);
// Build new node. We use same name as original node, but change the op name.
@@ -1596,8 +1741,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
TF_CHECK_OK(nb.Finalize(&**g, &new_node));
CHECK_NOTNULL(new_node);
- // Incoming edges from 'orig_node' node to new 'new_node' node are already
- // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
+ // Incoming data edges from 'orig_node' node to new 'new_node' node are
+ // already copied in BuildNode. We need to handle control edges now.
+ for (const Edge* e : orig_node->in_edges()) {
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ }
+ }
+
+ // Copy outgoing edges from 'orig_node' node to new
// 'new_node' node, since the output also follows same ordering among
// Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
// tensors appropriately. Specifically, nth output of the original node
@@ -1605,15 +1757,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// of the tensors. For the contiguous ordering of the tensors, it will be n.
// GetTensorDataIndex provides this mapping function.
for (const Edge* e : orig_node->out_edges()) {
- // We need to handle control-edges by using their original slot number.
- // Generally, -1 is reserved for control slot.
- if (e->src_output() < 0) {
- (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
+ if (e->IsControlEdge()) {
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
} else {
- (*g)->AddEdge(
- new_node,
- GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
- e->dst(), e->dst_input());
+ CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
+ e->src()->num_outputs()),
+ e->dst(), e->dst_input()));
}
}
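The slot remapping mentioned in the comment has a simple closed form. A hedged sketch (the real helper, GetTensorDataIndex, lives in mkl_util.h; the split into four named functions here is illustrative): interleaved ordering moves data output n to slot 2n with its Mkl metadata at 2n+1, while contiguous ordering keeps data at slot n and appends all metadata after the data outputs.

// Illustrative output-slot arithmetic, not the mkl_util.h implementation.
// 'n' is the output index on the original node; 'num_data_outputs' is the
// number of data (non-Mkl) outputs on the rewritten node.
int DataSlotInterleaved(int n) { return 2 * n; }
int MklSlotInterleaved(int n) { return 2 * n + 1; }
int DataSlotContiguous(int n) { return n; }
int MklSlotContiguous(int n, int num_data_outputs) {
  return num_data_outputs + n;
}
// E.g. output 1 of a node with 2 data outputs lands at slot 2 when
// interleaved, or stays at slot 1 with its Mkl tensor at slot 3 when
// contiguous.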
@@ -1640,8 +1789,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
bool is_matching_cinfo_found = false;
std::vector<const ContextInfo*> mci;
for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
- if (n->type_string() == ci->node) {
- mci.push_back(&*ci);
+ if (n->type_string() == (*ci)->node) {
+ mci.push_back(*ci);
is_matching_cinfo_found = true;
}
}
@@ -1701,9 +1850,10 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
return nullptr;
}
-bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
+bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n,
+ const ContextInfo* c) {
const Node* fwd_node = nullptr;
- return SearchMatchingContext(n, &fwd_node) != nullptr;
+ return SearchMatchingContext(n, &fwd_node) == c;
}
const MklLayoutRewritePass::RewriteInfo*
@@ -1719,18 +1869,29 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
return nullptr;
}
- if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
- return nullptr;
+ // BiasAddGrad is not an Mkl layer, so we make an exception for it.
+ if (n->type_string() != csinfo_.bias_add_grad) {
+ if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
+ return nullptr;
+ }
}
// We support 2 types of node rewrites:
- // 1. Rewriting BiasAddGrad depending on its context.
+ // 1. Rewriting BiasAddGrad depending on its MklConv2DWithBias context.
// 2. Rewriting an op to Mkl op always
// We return true if any of these 2 conditions is met.
// Find matching RewriteInfo and then check that rewrite rule applies.
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
- if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
+ if (n->type_string().compare(ri->name) == 0 &&
+ ri->rewrite_rule(n, ri->context)) {
+ // If BiasAddGrad would be "rewritten" into BiasAddGrad for a MatMul
+ // context (i.e., left unchanged), there is nothing to do, so return
+ // directly.
+ if (n->type_string() == csinfo_.bias_add_grad &&
+ ri->context->fwd == csinfo_.matmul &&
+ ri->new_name == csinfo_.bias_add_grad) {
+ return nullptr;
+ }
return &*ri;
}
}
@@ -1753,7 +1914,8 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
- if (!n->IsOp()) {
+ // If node is not an op or it cannot run on CPU device, then skip.
+ if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
continue;
}
@@ -1801,18 +1963,31 @@ bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
return MklLayoutRewritePass().RunPass(g);
}
-Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
- if (options.graph == nullptr) {
+Status MklLayoutRewritePass::Run(
+ const GraphOptimizationPassOptions& options) {
+ if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
- // Get the ownership of graph
- std::unique_ptr<Graph>* g = std::move(options.graph);
-
- RunPass(g);
-
- // Return the ownership of graph back
- options.graph->reset(g->release());
+ auto process_graph = [&](std::unique_ptr<Graph>* g) {
+ // Get the ownership of a graph
+ std::unique_ptr<Graph>* ng = std::move(g);
+ RunPass(ng);
+ // Return the ownership of a graph back
+ g->reset(ng->release());
+ };
+
+ if (kMklLayoutRewritePassGroup !=
+ OptimizationPassRegistry::POST_PARTITIONING) {
+ // For any pre-partitioning phase, a graph is stored in options.graph.
+ process_graph(options.graph);
+ } else {
+ // For the post-partitioning phase, graphs are stored in
+ // options.partition_graphs.
+ for (auto& pg : *options.partition_graphs) {
+ process_graph(&pg.second);
+ }
+ }
return Status::OK();
}
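Since the pass now runs in the post-partitioning group, Run() may receive either one whole graph or a map of per-device partitions. The dispatch reduces to the following self-contained sketch, with FakeGraph standing in for tensorflow::Graph:

#include <map>
#include <memory>
#include <string>

struct FakeGraph {};  // stand-in for tensorflow::Graph
using GraphPtr = std::unique_ptr<FakeGraph>;

void RunPassOn(GraphPtr* g) { /* rewrite nodes of *g in place */ }

// Pre-partitioning passes receive a single graph; a post-partitioning pass
// must walk every partition, as MklLayoutRewritePass::Run now does.
void RunEverywhere(GraphPtr* whole, std::map<std::string, GraphPtr>* parts,
                   bool post_partitioning) {
  if (!post_partitioning) {
    RunPassOn(whole);
  } else {
    for (auto& pg : *parts) RunPassOn(&pg.second);
  }
}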