diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 277 |
1 files changed, 84 insertions, 193 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 625780e7c9..94741a11ff 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -247,10 +247,16 @@ namespace tensorflow { // // P = Conv2DWithBiasBackpropBias(O, O_m) // -// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending -// on the matching 'context'. The term context is loosely related to which -// forward op is _associated_ to BiasAddGrad. If it is _MklConv2DWithBias then -// we consider it Conv2D context; if it is MatMul, then it is MatMul context. +// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is +// the context matching depth. If _MklConv2DWithBias is not within the context +// matching depth, then we do not rewrite BiasAddGrad. + +// How many hops do we search for matching node in the backward dataflow graph? +// We use maxhop of 10 based on empirical observations. Also, these are +// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D) +// directly goes to backward nodes, we do not expect the hop-distance +// would be more than few nodes. +static size_t kNodeMergeContextMaxDepth = 10; class MklLayoutRewritePass : public GraphOptimizationPass { public: @@ -274,8 +280,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.max_pool = "MaxPool"; csinfo_.max_pool_grad = "MaxPoolGrad"; csinfo_.mkl_conv2d = "_MklConv2D"; - csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput"; - csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter"; csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias"; csinfo_.mkl_conv2d_with_bias_backprop_bias = "_MklConv2DWithBiasBackpropBias"; @@ -356,12 +360,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0, csinfo_.mkl_conv2d_with_bias}); + // We use maxhop of 10 based on empirical observations. Also, these are + // maxhops in backward data-flow graph. Since input of forward nodes + // (Conv2D) directly goes to backward nodes, we do not expect the + // hop-distance would be more than few nodes. biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul, - IsBiasAddGradInMatMulContext}; + kNodeMergeContextMaxDepth}; biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias, - IsBiasAddGradInConv2DWithBiasContext}; + kNodeMergeContextMaxDepth}; cinfo_.push_back(&biasaddgrad_matmul_context_); cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_); @@ -384,7 +392,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string node; // Name of the node to be rewritten string fwd; // Name of the node in the forward pass that this node // corresponds to - std::function<bool(const Node*, const Node**, void* c)> context_match_fn; + size_t max_hop; // Maximum number of hops the fwd is located + // from this node. If the fwd is farther than max_hop + // then we do not rewrite the node. } ContextInfo; /// Structure to specify the name of an original node, its new name after @@ -428,7 +438,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Structure to store all constant strings /// NOTE: names are alphabetically sorted. - typedef struct { + struct { string avg_pool; string avg_pool_grad; string bias_add; @@ -447,15 +457,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string max_pool; string max_pool_grad; string mkl_conv2d; - string mkl_conv2d_grad_input; - string mkl_conv2d_grad_filter; string mkl_conv2d_with_bias; string mkl_conv2d_with_bias_backprop_bias; string relu; string relu_grad; string reshape; string split; - } ConstStringsInfo; + } csinfo_; private: /// Maintain info about nodes to rewrite @@ -470,9 +478,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { /// Maintain info about nodes to rewrite static std::vector<ContextInfo*> cinfo_; - /// Maintain structure of constant strings - static ConstStringsInfo csinfo_; - /// Context variables used in referencing rules static ContextInfo biasaddgrad_matmul_context_; static ContextInfo biasaddgrad_conv2dwithbias_context_; @@ -624,173 +629,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { return false; } - // Is BiasAddGrad node in 'n' is associated with Conv2DWithBias node - // specified in contextinfo 'ci'. Function updates fwd_node to point - // to Conv2DWithBias node if 'n' is associated with Conv2DWithBias. - // - // Association checks for one of the following graphs: - // - // Graph A: - // - // _ = Conv2DWithBias(F, I, _) - // .. - // _ = Conv2DBackpropFilter(F, _, G) - // _ = Conv2DBackpropInput(_, I, G) - // _ = BiasAddGrad(G) - // - // OR - // - // Graph B: - // - // _ = Conv2DWithBias(F, _, _) - // .. - // _ = Conv2DBackpropFilter(F, _, G) - // _ = BiasAddGrad(G) - // - // Here F, G, and I are graph nodes; _ represents graph nodes that we - // don't care here. - // - // @return - true (if BiasAddGrad is associated with Conv2DWithBias); - // false otherwise. - static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n, - const Node** fwd_node, - void* ci) { - CHECK_NOTNULL(n); - CHECK_NOTNULL(fwd_node); - CHECK_NOTNULL(ci); - *fwd_node = nullptr; - - CHECK_EQ(n->type_string(), csinfo_.bias_add_grad); - - // Get the only 1 input of BiasAddGrad. - CHECK_EQ(n->num_inputs(), 1); - const Node* bias_add_grad_inp = nullptr; - TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp)); - CHECK_NOTNULL(bias_add_grad_inp); - - // Check if this input also goes to BackpropFilter and BackpropInput - // as 3rd input. - bool found_backprop_input = false; - bool found_backprop_filter = false; - Node* backprop_filter_node = nullptr; - Node* backprop_input_node = nullptr; - - for (const Edge* e : bias_add_grad_inp->out_edges()) { - Node* third_input = nullptr; - if (e->dst()->type_string() == csinfo_.conv2d_grad_input || - e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) { - // Third input (index 2) of BackpropInput - TF_CHECK_OK(e->dst()->input_node(2, &third_input)); - // Third input (index 2) of BackpropInput must be same as the input - // of BiasAddGrad. - if (third_input == bias_add_grad_inp) { - found_backprop_input = true; - backprop_input_node = e->dst(); - } - } - - if (e->dst()->type_string() == csinfo_.conv2d_grad_filter || - e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) { - // Third input (index 2) of BackpropFilter - TF_CHECK_OK(e->dst()->input_node(2, &third_input)); - // Third input (index 2) of BackpropFilter must be same as the input - // of BiasAddGrad. - if (third_input == bias_add_grad_inp) { - found_backprop_filter = true; - backprop_filter_node = e->dst(); - } - } - - // If we found both the nodes, then we can stop the search. - if (found_backprop_input && found_backprop_filter) { - break; - } - } - - // If BackpropFilter node is not found, then this is not - // Conv2DWithBias context. For 2nd graph in the example above, only - // BackpropFilter would be present. - if (!found_backprop_filter) { - return false; - } - - // Otherwise, we found the nodes. - CHECK_NOTNULL(backprop_filter_node); - if (found_backprop_input) { - CHECK_NOTNULL(backprop_input_node); - } - - // Now that we confirmed that this is Conv2DWithBias context, we need to - // get access to the forward node (Conv2DWithBias). 2nd input of - // Conv2DWithBias is same as the 2nd input of Conv2DBackpropInput; 1st - // input of Conv2DWithBias is same as the 1st input of Conv2DBackpropFilter - // (This comes from definition of gradient computation for Conv2D). - if (found_backprop_input) { - // Graph A in the example. - Node* second_inp_of_input = nullptr; - Node* first_inp_of_filter = nullptr; - TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input)); - TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); - CHECK_NOTNULL(second_inp_of_input); - CHECK_NOTNULL(first_inp_of_filter); - - // Now we need to find out Conv2DWithBias node from these input nodes. - // Conv2DWithBias node is the node that accepts both the nodes - // second_inp_of_input and first_inp_of_filter in 2nd and 1st input slots. - for (const Edge* fe : first_inp_of_filter->out_edges()) { - if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - fe->dst_input() == 0) { - for (const Edge* ie : second_inp_of_input->out_edges()) { - if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - ie->dst_input() == 1 && fe->dst() == ie->dst()) { - VLOG(1) << "MklLayoutRewritePass: found " - << fe->dst()->DebugString() - << " as the forward node for matching context, backward" - << " node is: " << n->DebugString(); - *fwd_node = fe->dst(); - return true; - } - } - } - } - } else { - // We did not find BackpropInput, so we work with BackpropFilter only. - // Graph B in the example. - Node* first_inp_of_filter = nullptr; - TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter)); - CHECK_NOTNULL(first_inp_of_filter); - - // Now we need to find out Conv2DWithBias node from first input of - // BackpropFIlter. Conv2DWithBias node is the node that accepts - // first_inp_of_filter in 1st input slot. - for (const Edge* fe : first_inp_of_filter->out_edges()) { - if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias && - fe->dst_input() == 0) { - VLOG(1) << "MklLayoutRewritePass: found " - << fe->dst()->DebugString() - << " as the forward node for matching context, backward" - << " node is: " << n->DebugString(); - *fwd_node = fe->dst(); - return true; - } - } - } - - return false; - } - - // Is BiasAddGrad node in 'n' is associated with MatMul node - // specified in contextinfo 'ci'. Function does not update fwd_node. - // - // @return - true (if BiasAddGrad is associated with MatMul); - // false otherwise. - static bool IsBiasAddGradInMatMulContext(const Node* n, - const Node** fwd_node, - void* ci) { - return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci)); - } - - // Rewrite rule that uses context-information for matching, // used in scenario 2. // @@ -801,6 +639,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static bool ContextMatchRewrite(const Node* n, const ContextInfo* c); // Helper function that searches the matching contextinfo for the node. + // Implements depth-first search in the data dependence graph for the + // gradient op in the backward direction. // // @input n - Node (gradient op) whose contextinfo is to be searched, // fwd_node - pointer to node from the forward pass that this node @@ -948,7 +788,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { Node* orig_node); }; -MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; MklLayoutRewritePass::ContextInfo MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_; MklLayoutRewritePass::ContextInfo @@ -1828,12 +1667,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, const ContextInfo* ci = nullptr; bool is_context_based_rewrite = false; if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) { + CHECK_NOTNULL(fwd_node); is_context_based_rewrite = true; // Sanity checks for context-based rewrite (if any) if (orig_node->type_string() == csinfo_.bias_add_grad && ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) { - CHECK_NOTNULL(fwd_node); DataType orig_T, ctx_T; string orig_data_format, ctx_data_format; TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T)); @@ -1945,17 +1784,69 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, CHECK_NOTNULL(fwd_node); *fwd_node = nullptr; - // Search for matching contextinfo based on node name and call - // callback function using matching contextinfo. - // There could be more than one matching contextinfos but whichever - // matches first is returned. + // Search for matching contextinfo based on node name. + // There could be more than one matching contextinfos. + bool is_matching_cinfo_found = false; + std::vector<const ContextInfo*> mci; for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) { - if (n->type_string() == (*ci)->node && - (*ci)->context_match_fn(n, fwd_node, *ci)) { - VLOG(1) << "Found context as matching: " << (*ci)->fwd; - return *ci; + if (n->type_string() == (*ci)->node) { + mci.push_back(*ci); + is_matching_cinfo_found = true; } } + // If no matching contextinfo is found, return immediately. + if (!is_matching_cinfo_found) { + return nullptr; + } + + VLOG(1) << "MklLayoutRewritePass: Searching graph for: " << n->type_string() + << " in backwards."; + + // Now we will check for forward op name for context info in data + // flow graph. Get the max hops we should search for the fwd node. + // We are now going to search (breadth-first) backwards in data + // dependence graph (for up to max hops) from n for the node + // specified in fwd. + // queue to maintain nodes to be visited and depth info for + // breadth-first search + std::queue<std::pair<const Node*, int>> nqueue; + const Node* curr_node = n; + size_t curr_depth = 0; + nqueue.push(std::make_pair(curr_node, curr_depth)); + + while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) { + std::pair<const Node*, int> curr_pair = nqueue.front(); + nqueue.pop(); + + std::set<const Node*> visited_nodes; + curr_node = curr_pair.first; + curr_depth = curr_pair.second; + CHECK_NOTNULL(curr_node); + + VLOG(1) << "MklLayoutRewritePass: Visiting node: " + << curr_node->type_string() << " at depth: " << curr_depth + << " for node: " << n->type_string(); + + // If we find a match, we return immediately. + for (const ContextInfo* ci : mci) { + if (curr_node->type_string() == ci->fwd) { + *fwd_node = curr_node; + return ci; + } + } + + // Else we explore backward edges from current node. + // Add the source nodes of all incoming edges of the node to the queue. + for (const Edge* e : curr_node->in_edges()) { + // We do not visit already visited node. + if (visited_nodes.find(e->src()) == visited_nodes.end()) { + // Depth of these nodes is 1 more than the depth of current node. + nqueue.push(std::make_pair(e->src(), curr_depth + 1)); + visited_nodes.insert(e->src()); + } + } + } /* while */ + return nullptr; } |