Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
 tensorflow/core/graph/mkl_layout_pass.cc | 277
 1 file changed, 84 insertions(+), 193 deletions(-)
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 625780e7c9..94741a11ff 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -247,10 +247,16 @@ namespace tensorflow {
//
// P = Conv2DWithBiasBackpropBias(O, O_m)
//
-// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
-// on the matching 'context'. The term context is loosely related to which
-// forward op is _associated_ to BiasAddGrad. If it is _MklConv2DWithBias then
-// we consider it Conv2D context; if it is MatMul, then it is MatMul context.
+// The 'distance', in hops in the backward dataflow graph, between the input
+// of BiasAddGrad and _MklConv2DWithBias is the context matching depth. If
+// _MklConv2DWithBias is not within that depth, we do not rewrite BiasAddGrad.
+
+// How many hops do we search for a matching node in the backward dataflow
+// graph? We use a maximum of 10 hops based on empirical observations. These
+// are hops in the backward dataflow graph; since the inputs of a forward
+// node (Conv2D) feed directly into its backward nodes, we do not expect the
+// hop distance to be more than a few nodes.
+static const size_t kNodeMergeContextMaxDepth = 10;
class MklLayoutRewritePass : public GraphOptimizationPass {
public:
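For intuition, here is a minimal, self-contained sketch of the depth-capped backward breadth-first search that kNodeMergeContextMaxDepth bounds. The toy graph, the op names in it, and the BackwardHopDistance helper are hypothetical illustrations, not code from the pass:

#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <utility>
#include <vector>

// For each node, the list of its input (fanin) nodes, i.e. its backward edges.
using BackwardGraph = std::map<std::string, std::vector<std::string>>;

// Returns the hop distance from 'start' to 'target' following backward edges,
// or -1 if 'target' is not within 'max_depth' hops.
int BackwardHopDistance(const BackwardGraph& g, const std::string& start,
                        const std::string& target, size_t max_depth) {
  std::queue<std::pair<std::string, size_t>> q;
  std::set<std::string> visited = {start};
  q.push({start, 0});
  while (!q.empty()) {
    std::pair<std::string, size_t> cur = q.front();
    q.pop();
    if (cur.first == target) return static_cast<int>(cur.second);
    if (cur.second >= max_depth) continue;  // do not expand past the cap
    auto it = g.find(cur.first);
    if (it == g.end()) continue;  // no fanin recorded for this node
    for (const std::string& src : it->second) {
      if (visited.insert(src).second) q.push({src, cur.second + 1});
    }
  }
  return -1;
}

int main() {
  // Hypothetical backprop fragment: BiasAddGrad's gradient input comes from
  // ReluGrad, which in turn reads the output of _MklConv2DWithBias.
  BackwardGraph g = {{"BiasAddGrad", {"ReluGrad"}},
                     {"ReluGrad", {"_MklConv2DWithBias"}}};
  // Prints 2: the forward op is two hops away, well under the cap of 10.
  std::cout << BackwardHopDistance(g, "BiasAddGrad", "_MklConv2DWithBias", 10)
            << std::endl;
}

A shared visited set keeps the search linear in the number of nodes and edges within the cap.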
@@ -274,8 +280,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
- csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
- csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"_MklConv2DWithBiasBackpropBias";
@@ -356,12 +360,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
csinfo_.mkl_conv2d_with_bias});
+    // Both contexts use kNodeMergeContextMaxDepth hops; see the comment on
+    // that constant for why 10 hops suffice empirically.
biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
- IsBiasAddGradInMatMulContext};
+ kNodeMergeContextMaxDepth};
biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
csinfo_.mkl_conv2d_with_bias,
- IsBiasAddGradInConv2DWithBiasContext};
+ kNodeMergeContextMaxDepth};
cinfo_.push_back(&biasaddgrad_matmul_context_);
cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
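With these registrations, both contexts are resolved by one mechanism: SearchMatchingContext (below) walks backwards from the BiasAddGrad node and returns the first registered forward op, MatMul or _MklConv2DWithBias, that it reaches within the hop budget.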
@@ -384,7 +392,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string node; // Name of the node to be rewritten
string fwd; // Name of the node in the forward pass that this node
// corresponds to
- std::function<bool(const Node*, const Node**, void* c)> context_match_fn;
+    size_t max_hop;  // Maximum number of hops 'fwd' may be from this node
+                     // in the backward dataflow graph. If 'fwd' is farther
+                     // than max_hop, we do not rewrite this node.
} ContextInfo;
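As a rough illustration of how such a triple gates a rewrite, here is a standalone sketch; ContextInfoSketch and Matches are hypothetical names, not code from the pass:

#include <cstddef>
#include <string>

// Hypothetical mirror of the ContextInfo triple above.
struct ContextInfoSketch {
  std::string node;  // op type to rewrite, e.g. "BiasAddGrad"
  std::string fwd;   // forward op type to look for, e.g. "_MklConv2DWithBias"
  size_t max_hop;    // backward-search budget in hops
};

// A forward op found at 'distance' hops matches the context only if its type
// agrees with 'fwd' and it lies within the hop budget.
bool Matches(const ContextInfoSketch& ci, const std::string& fwd_op,
             size_t distance) {
  return fwd_op == ci.fwd && distance <= ci.max_hop;
}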
/// Structure to specify the name of an original node, its new name after
@@ -428,7 +438,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Structure to store all constant strings
/// NOTE: names are alphabetically sorted.
- typedef struct {
+ struct {
string avg_pool;
string avg_pool_grad;
string bias_add;
@@ -447,15 +457,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string max_pool;
string max_pool_grad;
string mkl_conv2d;
- string mkl_conv2d_grad_input;
- string mkl_conv2d_grad_filter;
string mkl_conv2d_with_bias;
string mkl_conv2d_with_bias_backprop_bias;
string relu;
string relu_grad;
string reshape;
string split;
- } ConstStringsInfo;
+ } csinfo_;
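The csinfo_ member above now has an unnamed struct type and is declared in place as a non-static member, which avoids the separate out-of-class definition its named, static predecessor required (that definition is deleted further down in this patch). A minimal sketch of the pattern, with the hypothetical name PassSketch:

#include <string>

class PassSketch {
 public:
  PassSketch() {
    names_.conv2d = "Conv2D";
    names_.bias_add = "BiasAdd";
  }

 private:
  // Member of unnamed struct type: groups related constant strings without
  // exposing a type name or requiring a static out-of-class definition.
  struct {
    std::string conv2d;
    std::string bias_add;
  } names_;
};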
private:
/// Maintain info about nodes to rewrite
@@ -470,9 +478,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// Maintain info about nodes to rewrite
static std::vector<ContextInfo*> cinfo_;
- /// Maintain structure of constant strings
- static ConstStringsInfo csinfo_;
-
/// Context variables used in referencing rules
static ContextInfo biasaddgrad_matmul_context_;
static ContextInfo biasaddgrad_conv2dwithbias_context_;
@@ -624,173 +629,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
- // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node
- // specified in contextinfo 'ci'. Function updates fwd_node to point
- // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
- //
- // Association checks for one of the following graphs:
- //
- // Graph A:
- //
- // _ = Conv2DWithBias(F, I, _)
- // ..
- // _ = Conv2DBackpropFilter(F, _, G)
- // _ = Conv2DBackpropInput(_, I, G)
- // _ = BiasAddGrad(G)
- //
- // OR
- //
- // Graph B:
- //
- // _ = Conv2DWithBias(F, _, _)
- // ..
- // _ = Conv2DBackpropFilter(F, _, G)
- // _ = BiasAddGrad(G)
- //
- // Here F, G, and I are graph nodes; _ represents graph nodes that we
- // don't care here.
- //
- // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
- // false otherwise.
- static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
- const Node** fwd_node,
- void* ci) {
- CHECK_NOTNULL(n);
- CHECK_NOTNULL(fwd_node);
- CHECK_NOTNULL(ci);
- *fwd_node = nullptr;
-
- CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);
-
- // Get the only 1 input of BiasAddGrad.
- CHECK_EQ(n->num_inputs(), 1);
- const Node* bias_add_grad_inp = nullptr;
- TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
- CHECK_NOTNULL(bias_add_grad_inp);
-
- // Check if this input also goes to BackpropFilter and BackpropInput
- // as 3rd input.
- bool found_backprop_input = false;
- bool found_backprop_filter = false;
- Node* backprop_filter_node = nullptr;
- Node* backprop_input_node = nullptr;
-
- for (const Edge* e : bias_add_grad_inp->out_edges()) {
- Node* third_input = nullptr;
- if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
- e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
- // Third input (index 2) of BackpropInput
- TF_CHECK_OK(e->dst()->input_node(2, &third_input));
- // Third input (index 2) of BackpropInput must be same as the input
- // of BiasAddGrad.
- if (third_input == bias_add_grad_inp) {
- found_backprop_input = true;
- backprop_input_node = e->dst();
- }
- }
-
- if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
- e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
- // Third input (index 2) of BackpropFilter
- TF_CHECK_OK(e->dst()->input_node(2, &third_input));
- // Third input (index 2) of BackpropFilter must be same as the input
- // of BiasAddGrad.
- if (third_input == bias_add_grad_inp) {
- found_backprop_filter = true;
- backprop_filter_node = e->dst();
- }
- }
-
- // If we found both the nodes, then we can stop the search.
- if (found_backprop_input && found_backprop_filter) {
- break;
- }
- }
-
- // If BackpropFilter node is not found, then this is not
- // Conv2DWithBias context. For 2nd graph in the example above, only
- // BackpropFilter would be present.
- if (!found_backprop_filter) {
- return false;
- }
-
- // Otherwise, we found the nodes.
- CHECK_NOTNULL(backprop_filter_node);
- if (found_backprop_input) {
- CHECK_NOTNULL(backprop_input_node);
- }
-
- // Now that we confirmed that this is Conv2DWithBias context, we need to
- // get access to the forward node (Conv2DWithBias). 2nd input of
- // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st
- // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter
- // (This comes from definition of gradient computation for Conv2D).
- if (found_backprop_input) {
- // Graph A in the example.
- Node* second_inp_of_input = nullptr;
- Node* first_inp_of_filter = nullptr;
- TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
- TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
- CHECK_NOTNULL(second_inp_of_input);
- CHECK_NOTNULL(first_inp_of_filter);
-
- // Now we need to find out Conv2DWithBias node from these input nodes.
- // Conv2DWithBias node is the node that accepts both the nodes
- // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots.
- for (const Edge* fe : first_inp_of_filter->out_edges()) {
- if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- fe->dst_input() == 0) {
- for (const Edge* ie : second_inp_of_input->out_edges()) {
- if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- ie->dst_input() == 1 && fe->dst() == ie->dst()) {
- VLOG(1) << "MklLayoutRewritePass: found "
- << fe->dst()->DebugString()
- << " as the forward node for matching context, backward"
- << " node is: " << n->DebugString();
- *fwd_node = fe->dst();
- return true;
- }
- }
- }
- }
- } else {
- // We did not find BackpropInput, so we work with BackpropFilter only.
- // Graph B in the example.
- Node* first_inp_of_filter = nullptr;
- TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
- CHECK_NOTNULL(first_inp_of_filter);
-
- // Now we need to find out Conv2DWithBias node from first input of
- // BackpropFIlter. Conv2DWithBias node is the node that accepts
- // first_inp_of_filter in 1st input slot.
- for (const Edge* fe : first_inp_of_filter->out_edges()) {
- if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
- fe->dst_input() == 0) {
- VLOG(1) << "MklLayoutRewritePass: found "
- << fe->dst()->DebugString()
- << " as the forward node for matching context, backward"
- << " node is: " << n->DebugString();
- *fwd_node = fe->dst();
- return true;
- }
- }
- }
-
- return false;
- }
-
- // Is BiasAddGrad node in 'n' is associated with MatMul node
- // specified in contextinfo 'ci'. Function does not update fwd_node.
- //
- // @return - true (if BiasAddGrad is associated with MatMul);
- // false otherwise.
- static bool IsBiasAddGradInMatMulContext(const Node* n,
- const Node** fwd_node,
- void* ci) {
- return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
- }
-
-
// Rewrite rule that uses context-information for matching,
// used in scenario 2.
//
@@ -801,6 +639,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
// Helper function that searches the matching contextinfo for the node.
+  // Implements breadth-first search in the data dependence graph, walking
+  // backwards from the gradient op.
//
// @input n - Node (gradient op) whose contextinfo is to be searched,
// fwd_node - pointer to node from the forward pass that this node
@@ -948,7 +788,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
Node* orig_node);
};
-MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
MklLayoutRewritePass::ContextInfo
MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
MklLayoutRewritePass::ContextInfo
@@ -1828,12 +1667,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
const ContextInfo* ci = nullptr;
bool is_context_based_rewrite = false;
if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
+ CHECK_NOTNULL(fwd_node);
is_context_based_rewrite = true;
// Sanity checks for context-based rewrite (if any)
if (orig_node->type_string() == csinfo_.bias_add_grad &&
ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
- CHECK_NOTNULL(fwd_node);
DataType orig_T, ctx_T;
string orig_data_format, ctx_data_format;
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
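The sanity check continues beyond this hunk by fetching and comparing the remaining attributes; in spirit it reduces to a predicate like the hypothetical sketch below (AttrsCompatible and DataTypeSketch are illustrative names, not helpers from the pass):

#include <string>

// Stand-in for TensorFlow's DataType enum.
enum class DataTypeSketch { DT_FLOAT, DT_HALF };

// Fuse BiasAddGrad with the matched forward node only if both agree on the
// element type 'T' and the data format.
bool AttrsCompatible(DataTypeSketch orig_T, DataTypeSketch ctx_T,
                     const std::string& orig_df, const std::string& ctx_df) {
  return orig_T == ctx_T && orig_df == ctx_df;
}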
@@ -1945,17 +1784,69 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
CHECK_NOTNULL(fwd_node);
*fwd_node = nullptr;
- // Search for matching contextinfo based on node name and call
- // callback function using matching contextinfo.
- // There could be more than one matching contextinfos but whichever
- // matches first is returned.
+  // Collect all contextinfos whose 'node' field matches this node's type;
+  // there can be more than one.
+ std::vector<const ContextInfo*> mci;
for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
- if (n->type_string() == (*ci)->node &&
- (*ci)->context_match_fn(n, fwd_node, *ci)) {
- VLOG(1) << "Found context as matching: " << (*ci)->fwd;
- return *ci;
+ if (n->type_string() == (*ci)->node) {
+ mci.push_back(*ci);
}
}
+ // If no matching contextinfo is found, return immediately.
+  if (mci.empty()) {
+ return nullptr;
+ }
+
+  VLOG(1) << "MklLayoutRewritePass: Searching the graph backwards for: "
+          << n->type_string();
+
+  // Now search backwards (breadth-first) in the data dependence graph,
+  // starting from n and going at most kNodeMergeContextMaxDepth hops, for
+  // the forward op named in a matching contextinfo.
+  // The queue holds the nodes still to be visited, together with their
+  // depth, for the breadth-first search.
+ std::queue<std::pair<const Node*, int>> nqueue;
+ const Node* curr_node = n;
+ size_t curr_depth = 0;
+ nqueue.push(std::make_pair(curr_node, curr_depth));
+
+  // Set of nodes already enqueued. The set is shared across the whole
+  // search so that a node reachable along multiple paths is enqueued (and
+  // hence visited) only once.
+  std::set<const Node*> visited_nodes;
+  visited_nodes.insert(curr_node);
+
+  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
+    std::pair<const Node*, int> curr_pair = nqueue.front();
+    nqueue.pop();
+
+ curr_node = curr_pair.first;
+ curr_depth = curr_pair.second;
+ CHECK_NOTNULL(curr_node);
+
+ VLOG(1) << "MklLayoutRewritePass: Visiting node: "
+ << curr_node->type_string() << " at depth: " << curr_depth
+ << " for node: " << n->type_string();
+
+    // If we find a match within a context's hop budget, return immediately.
+    for (const ContextInfo* ci : mci) {
+      if (curr_node->type_string() == ci->fwd && curr_depth <= ci->max_hop) {
+ *fwd_node = curr_node;
+ return ci;
+ }
+ }
+
+    // Otherwise, explore the backward edges from the current node:
+    // add the source node of every incoming edge to the queue.
+ for (const Edge* e : curr_node->in_edges()) {
+      // Skip nodes that have already been visited.
+      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
+        // The depth of a source node is one more than the current depth.
+ nqueue.push(std::make_pair(e->src(), curr_depth + 1));
+ visited_nodes.insert(e->src());
+ }
+ }
+ } /* while */
+
return nullptr;
}
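Because visited_nodes is shared across the whole search, each node is enqueued at most once, so SearchMatchingContext does work proportional to the number of nodes and edges within kNodeMergeContextMaxDepth backward hops of n, and it terminates even if the graph contains cycles.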