path: root/tensorflow/core/graph/mkl_optimizer_merge.cc
Diffstat (limited to 'tensorflow/core/graph/mkl_optimizer_merge.cc')
-rw-r--r-- tensorflow/core/graph/mkl_optimizer_merge.cc | 124
1 file changed, 90 insertions(+), 34 deletions(-)
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.cc b/tensorflow/core/graph/mkl_optimizer_merge.cc
index 98fc268d28..bc5915eda2 100644
--- a/tensorflow/core/graph/mkl_optimizer_merge.cc
+++ b/tensorflow/core/graph/mkl_optimizer_merge.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <vector>
#include <queue>
#include <utility>
+#include <string>
+#include <memory>
#include "tensorflow/core/graph/mkl_optimizer_merge.h"
@@ -33,6 +35,8 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@@ -58,8 +62,8 @@ static size_t kNodeMergeContextMaxDepth = 10;
class NodeMergeRewritePass : public GraphOptimizationPass {
public:
NodeMergeRewritePass() {
- csinfo_.conv2d = "Conv2D";
- csinfo_.conv2dwithbias = "Conv2DWithBias";
+ csinfo_.conv2d = "MklConv2D";
+ csinfo_.conv2dwithbias = "MklConv2DWithBias";
csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
csinfo_.biasadd = "BiasAdd";
csinfo_.matmul = "MatMul";
@@ -72,6 +76,9 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
// maxhops in the backward data-flow graph. Since the input of forward
// nodes (Conv2D) goes directly to backward nodes, we do not expect the
// hop-distance to be more than a few nodes.
+ // TODO(nhasabni) Temporarily disabling rewrite of BiasAddGrad.
+ // Will enable it once we support the Conv2DWithBiasBackpropBias op.
+#if 0
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
{csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
@@ -80,6 +87,7 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
// because we do not have a separate Op for MatMulwithBias.
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
{csinfo_.matmul, kNodeMergeContextMaxDepth}});
+#endif
}
// Standard interface to run optimization pass
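The rinfo_ entries above pair a node to rewrite with its replacement and a context. A minimal sketch of the struct shapes those initializer lists assume; the member names here are illustrative, as the real declarations live in mkl_optimizer_merge.h and are not shown in this hunk:

  // Context under which a rewrite fires: a forward op 'fwd' must be
  // reachable within 'maxhop' hops in the backward data-flow graph.
  typedef struct {
    string fwd;
    size_t maxhop;
  } ContextInfo;

  // A rewrite rule: 'node' is rewritten into 'rewrite' when its
  // associated ContextInfo matches.
  typedef struct {
    string node;
    string rewrite;
    ContextInfo cinfo;
  } RewriteInfo;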
@@ -182,10 +190,16 @@ class NodeMergeRewritePass : public GraphOptimizationPass {
// @return Matching rewriteinfo in case a match is found; null otherwise.
const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
const Node** fwdn) const;
+
+ // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
+ // and return it in '*out'.
+ // TODO(nhasabni) We should move this to mkl_util.h
+ void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out);
};
-/// We register merge optimizer for phase 1 and MKLToTF insertion for phase 2.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
+// We register the merge optimizer at phase 2 of the pre-placement group.
+// Do not change the relative ordering of the Mkl passes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 2,
NodeMergeRewritePass);
static void FillInputs(const Node* n,
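Within an optimization group, registered passes run in ascending phase order, so moving this pass from phase 1 to phase 2 sequences it after whatever Mkl pass now occupies phase 1. A hedged illustration; the pass name below is hypothetical:

  // Registered at phase 1 of the same group, this hypothetical pass
  // would run before NodeMergeRewritePass, which is now at phase 2.
  REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
                        SomeEarlierMklLayoutPass);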
@@ -219,8 +233,6 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
}
}
- VLOG(1) << "FindNodeForMerge: " << a->type_string();
-
for (const MergeInfo* mi : matching_mi) {
const int N_in = a->num_inputs();
if (mi->op >= N_in) {
@@ -240,8 +252,6 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
continue;
}
- VLOG(1) << " FindNode: " << b->type_string();
-
gtl::InlinedVector<Node*, 4> b_control_edges;
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
FillInputs(b, &b_control_edges, &b_in);
@@ -258,6 +268,22 @@ Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
return nullptr;
}
+void NodeMergeRewritePass::GetDummyMklTensorNode(
+ std::unique_ptr<Graph>* g, Node** out) {
+ const DataType dt = DataTypeToEnum<uint8>::v();
+ TensorProto proto;
+ proto.set_dtype(dt);
+ uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
+ proto.set_tensor_content(const_cast<const void*>(
+ static_cast<void*>(&zero)), 8);
+ TensorShape dummy_shape({8});
+ dummy_shape.AsProto(proto.mutable_tensor_shape());
+ TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+ .Attr("value", proto)
+ .Attr("dtype", dt)
+ .Finalize(&**g, out));
+}
+
Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
Node* succ, Node* pred) {
CHECK_NOTNULL(succ);
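GetDummyMklTensorNode above serializes eight zero bytes into a Const node so that ops expecting an Mkl metadata input still receive a tensor. A standalone sketch, not part of the patch, showing how that proto decodes using only standard Tensor/TensorProto APIs:

  #include "tensorflow/core/framework/tensor.h"
  #include "tensorflow/core/framework/tensor.pb.h"
  #include "tensorflow/core/platform/logging.h"

  // Round-trips the same proto that GetDummyMklTensorNode serializes: an
  // 8-element uint8 tensor of zeros, which downstream Mkl ops read as
  // "no Mkl layout metadata present".
  tensorflow::Tensor DecodeDummyMklTensor() {
    tensorflow::TensorProto proto;
    proto.set_dtype(tensorflow::DT_UINT8);
    tensorflow::uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    proto.set_tensor_content(
        std::string(reinterpret_cast<char*>(zero), 8));
    tensorflow::TensorShape({8}).AsProto(proto.mutable_tensor_shape());
    tensorflow::Tensor t;
    CHECK(t.FromProto(proto));  // decodes to a valid shape-{8} uint8 tensor
    CHECK_EQ(t.NumElements(), 8);
    return t;
  }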
@@ -271,7 +297,6 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
std::vector<int32> strides;
string data_format_pred, data_format_succ;
bool use_cudnn_on_gnu;
- int groups = 1;
TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
@@ -280,25 +305,28 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu",
&use_cudnn_on_gnu));
- // Groups attribute may not be there on the input node. So we do not
- // check for error in GetNodeAttr call.
- GetNodeAttr(pred->def(), "groups", &groups);
// We check to ensure that the data formats of both succ and pred are the
// same. We expect them to be the same, so we could enforce this with an
// assert; but an assert can be too strict, so we enforce it as a check.
// If the check fails, we do not merge the two nodes.
+ // We also do the same check for devices.
if (data_format_pred != data_format_succ ||
- T_pred != T_succ) {
+ T_pred != T_succ ||
+ pred->assigned_device_name() != succ->assigned_device_name() ||
+ pred->def().device() != succ->def().device()) {
return Status(error::Code::INVALID_ARGUMENT,
- "data_format or T attribute of Conv2D and BiasAdd"
- "do not match. Will skip node merge optimization");
+ "data_format or T attribute or devices of Conv2D and "
+ "BiasAdd do not match. Will skip node merge optimization");
}
// 2. Get inputs from both the nodes.
// Find the 2 inputs from the conv and the bias from the add Bias.
Node* oper1 = nullptr;
+ Node* oper1_mkl = nullptr; // Mkl tensor corresponding to oper1
Node* oper2 = nullptr;
+ Node* oper2_mkl = nullptr; // Mkl tensor corresponding to oper2
Node* oper3 = nullptr;
+ Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
const int succ_num = succ->num_inputs();
gtl::InlinedVector<Node*, 4> succ_control_edges;
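The expanded condition above gates merging on attribute and device agreement. A hypothetical helper, with names chosen by the editor rather than taken from the patch, expressing the same predicate:

  // Returns true only when Conv2D (pred) and BiasAdd (succ) agree on
  // dtype, data_format, and both requested and assigned devices.
  static bool MergeCompatible(const Node* pred, const Node* succ,
                              DataType T_pred, DataType T_succ,
                              const string& data_format_pred,
                              const string& data_format_succ) {
    return data_format_pred == data_format_succ && T_pred == T_succ &&
           pred->assigned_device_name() == succ->assigned_device_name() &&
           pred->def().device() == succ->def().device();
  }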
@@ -326,24 +354,35 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
}
}
- // Get operand 0, 1 of conv2D
- oper1 = pred_in[0].first;
- oper2 = pred_in[1].first;
+ // Get operand 0, 1 of conv2D and their Mkl tensors.
+ CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
+ oper1 = pred_in[0].first;
+ oper1_mkl = pred_in[1].first;
+ oper2 = pred_in[2].first;
+ oper2_mkl = pred_in[3].first;
// Get operand 1 of add_bias
- oper3 = succ_in[1].first;
+ // BiasAdd must have 2 inputs: Conv, bias
+ CHECK_EQ(succ->in_edges().size(), 2);
+ oper3 = succ_in[1].first;
+ GetDummyMklTensorNode(g, &oper3_mkl); // Get a dummy Mkl tensor node,
+ // as BiasAdd does not have an Mkl tensor as input.
+ CHECK_NOTNULL(oper3_mkl);
Node* ret;
// We will use the node name of BiasAdd as the name of the new node
TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
.Input(oper1)
+ .Input(oper1_mkl)
.Input(oper2)
+ .Input(oper2_mkl)
.Input(oper3)
+ .Input(oper3_mkl)
.Attr("T", T_pred)
.Attr("strides", strides)
.Attr("padding", padding)
.Attr("data_format", data_format_pred)
.Attr("use_cudnn_on_gpu", use_cudnn_on_gnu)
- .Attr("groups", groups)
+ .Device(succ->def().device())
.Finalize(&**g, &ret));
CHECK_NOTNULL(ret);
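The .Input() sequence above fixes the input layout of the merged MklConv2DWithBias node: each data tensor is immediately followed by its Mkl metadata tensor. A sketch of that ordering; the enumerator names are illustrative, not taken from the op registration:

  enum MklConv2DWithBiasInputSlot {
    kInput = 0,      // oper1: Conv2D input
    kInputMkl = 1,   // oper1_mkl
    kFilter = 2,     // oper2: Conv2D filter
    kFilterMkl = 3,  // oper2_mkl
    kBias = 4,       // oper3: bias from BiasAdd
    kBiasMkl = 5,    // oper3_mkl: dummy, since BiasAdd carries no Mkl tensor
  };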
@@ -352,6 +391,15 @@ Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g,
(*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
}
+ // Copy the device assigned to the old node to the new node.
+ // It's OK to use pred or succ, as we have enforced a check that
+ // both have the same device assigned.
+ ret->set_assigned_device_name(pred->assigned_device_name());
+
+ VLOG(1) << "NodeMergeRewritePass: Merged old node:" << pred->DebugString()
+ << ", and node: " << succ->DebugString() << ", into node:"
+ << ret->DebugString();
+
(*g)->RemoveNode(succ);
(*g)->RemoveNode(pred);
@@ -369,13 +417,14 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) {
const Node* fwdn = nullptr;
const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
if (ri == nullptr || fwdn == nullptr) {
- VLOG(1) << "Rewriteinfo not found for: " << n->type_string();
+ VLOG(2) << "NodeMergeRewritePass: Rewriteinfo not found for: "
+ << n->type_string();
return Status(error::Code::INVALID_ARGUMENT,
"Rewrite info not found for the node."
"Will skip node rewrite optimization");
}
- VLOG(1) << "Rewrite called for: " << n->type_string();
+ VLOG(1) << "NodeMergeRewritePass: Rewrite called for: " << n->type_string();
if (n->type_string() == csinfo_.biasaddgrad &&
ri->node == csinfo_.biasaddgrad &&
@@ -407,6 +456,7 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) {
.Attr("T", T)
.Attr("data_format", data_format)
.Attr("strides", strides)
+ .Device(n->def().device())
.Finalize(&**g, &ret));
} else {
CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
@@ -414,6 +464,7 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) {
.Input(op)
.Attr("T", T)
.Attr("data_format", data_format)
+ .Device(n->def().device())
.Finalize(&**g, &ret));
}
@@ -424,7 +475,11 @@ Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node *n) {
(*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
}
- VLOG(1) << "Rewrite node: " << n->type_string() << " successful";
+ // Copy the device assigned to the old node to the new node.
+ ret->set_assigned_device_name(n->assigned_device_name());
+
+ VLOG(1) << "MKLOptimizerMergePass: Rewrote old node:" << n->DebugString()
+ << ", into node:" << ret->DebugString();
(*g)->RemoveNode(n);
return Status::OK();
@@ -450,7 +505,8 @@ NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
}
}
- VLOG(1) << "Searching graph for: " << n->type_string() << " in backwards.";
+ VLOG(1) << "NodeMergeRewritePass: Searching graph for: "
+ << n->type_string() << " in backwards.";
// Now we will check for forward op name for rewrite info in data
// flow graph. Get the max hops we should search for the fwd node
@@ -473,7 +529,8 @@ NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
curr_depth = curr_pair.second;
CHECK_NOTNULL(curr_node);
- VLOG(1) << "Visiting node: " << curr_node->type_string()
+ VLOG(1) << "NodeMergeRewritePass: Visiting node: "
+ << curr_node->type_string()
<< " at depth: " << curr_depth
<< " for node: " << n->type_string();
@@ -528,17 +585,16 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
std::vector<Node*> nodes_to_be_rewritten;
- VLOG(1) << "Running NodeMerge Optimization";
-
for (Node* n : order) {
if (!n->IsOp()) continue;
Node* n1 = nullptr;
if ((n1 = FindNodeForMerge(n)) != nullptr) {
- VLOG(1) << "Scheduled nodes " << n->name() << " and "
- << n1->name() << " for merging";
+ VLOG(1) << "NodeMergeRewritePass: Scheduled nodes "
+ << n->name() << " and " << n1->name() << " for merging";
nodes_to_be_merged.push_back(std::make_pair(n, n1));
} else if (IsApplicableRewriteNode(n)) {
- VLOG(1) << "Scheduled node " << n->name() << " for rewrite";
+ VLOG(1) << "NodeMergeRewritePass: Scheduled node " << n->name()
+ << " for rewrite";
nodes_to_be_rewritten.push_back(n);
}
}
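RunPass walks the graph in topological order and only records candidates; the mutation happens in the later loops, so RemoveNode never invalidates the traversal. The collect-then-mutate shape, sketched under the assumption that 'order' comes from GetReversePostOrder in tensorflow/core/graph/algorithm.h:

  // Phase 1: snapshot the node order and schedule work (no mutation).
  std::vector<Node*> order;
  GetReversePostOrder(**g, &order);
  for (Node* n : order) {
    if (!n->IsOp()) continue;
    // ... FindNodeForMerge / IsApplicableRewriteNode, as above ...
  }
  // Phase 2: only now mutate; MergeNode and RewriteNode add and remove
  // nodes after the iteration over 'order' is complete.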
@@ -549,7 +605,8 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
string n1_name = i.first->name();
string n2_name = i.second->name();
if (MergeNode(g, i.first, i.second) == Status::OK()) {
- VLOG(1) << "Merged nodes " << n1_name << " and " << n2_name;
+ VLOG(1) << "NodeMergeRewritePass: Merged nodes " << n1_name
+ << " and " << n2_name;
result = true;
}
}
@@ -559,7 +616,8 @@ bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
for (Node* i : nodes_to_be_rewritten) {
string name = i->name();
if (RewriteNode(g, i) == Status::OK()) {
- VLOG(1) << "Rewrite node: " << name << " successful.";
+ VLOG(1) << "NodeMergeRewritePass: Rewrite node: "
+ << name << " successful.";
result = true;
}
}
@@ -574,8 +632,6 @@ bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
}
Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
- // Currently checking only for two cases - Conv2D+Bias and Matmul+Bias.
- // It is possible to extend it to other operators in future.
if (options.graph == nullptr) {
return Status::OK();
}