aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_layout_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc58
1 files changed, 42 insertions, 16 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 89b23f22fd..55bc401b9d 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2456,9 +2456,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAddN, AddNRewrite});
- rinfo_.push_back({csinfo_.add,
+ /* rinfo_.push_back({csinfo_.add,
mkl_op_registry::GetMklOpName(csinfo_.add),
- CopyAttrsDataType, AlwaysRewrite});
+ CopyAttrsDataType, AlwaysRewrite}); */
rinfo_.push_back({csinfo_.avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsPooling, AlwaysRewrite});
@@ -3117,7 +3117,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0,
const_cast<const Node**>(&orig_input0)));
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
}
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
@@ -3382,8 +3384,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
// We use a tensor of shape {1} and value 0 to represent
// dummy float tensor. We need this as a dummy workspace tensor.
- // Workspace tensor has type float.
- const DataType dt = DataTypeToEnum<float>::v();
+ // Workspace tensor has type uint8.
+ const DataType dt = DataTypeToEnum<uint8>::v();
TensorProto proto;
proto.set_dtype(dt);
float zero[1] = {0};
@@ -3413,7 +3415,9 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
Node* orig_input0 = nullptr;
TF_CHECK_OK(orig_node->input_node(0,
const_cast<const Node**>(&orig_input0)));
- CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
}
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
@@ -3863,12 +3867,16 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
// node are already copied in BuildNode. We handle control edges now.
for (const Edge* e : pred->in_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
}
}
for (const Edge* e : succ->in_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
}
}
@@ -3876,14 +3884,18 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
// First, we will fix outgoing control edges from 'pred' node.
for (const Edge* e : pred->out_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
}
}
// Second, we will fix outgoing control and data edges from 'succ' node.
for (const Edge* e : succ->out_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else {
// BiasAdd has only 1 output (at slot 0) and merged node also has only 1
// output (at slot 0).
@@ -3966,12 +3978,16 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
// edges now.
for (const Edge* e : badd->in_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
}
}
for (const Edge* e : fltr->in_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
}
}
@@ -3987,7 +4003,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
for (const Edge* e : badd->out_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else {
CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
e->dst(), e->dst_input()));
@@ -3997,7 +4015,11 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
// Second, we will fix outgoing control and data edges from 'fltr' node.
for (const Edge* e : fltr->out_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ // We allow duplicate edge for this case since we already add control
+ // edge from new_node in line 3990. Line below could be adding same
+ // edge to same destination again. In such case, if we do not allow
+ // duplicate edge, then this call will fail.
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else {
CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
e->dst(), e->dst_input()));
@@ -4091,7 +4113,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// already copied in BuildNode. We need to handle control edges now.
for (const Edge* e : orig_node->in_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true));
}
}
@@ -4104,7 +4128,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
// GetTensorDataIndex provides this mapping function.
for (const Edge* e : orig_node->out_edges()) {
if (e->IsControlEdge()) {
- CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst()));
+ // Allow duplicate while adding control edge as it would fail (return
+ // NULL) if we try to add duplicate edge.
+ CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true));
} else {
CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
e->src()->num_outputs()),