diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 58 |
1 files changed, 42 insertions, 16 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 89b23f22fd..55bc401b9d 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2456,9 +2456,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // NOTE: names are alphabetically sorted. rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, AddNRewrite}); - rinfo_.push_back({csinfo_.add, + /* rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), - CopyAttrsDataType, AlwaysRewrite}); + CopyAttrsDataType, AlwaysRewrite}); */ rinfo_.push_back({csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool), CopyAttrsPooling, AlwaysRewrite}); @@ -3117,7 +3117,9 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node* orig_input0 = nullptr; TF_CHECK_OK(orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); - CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); } (*out)->set_assigned_device_name(orig_node->assigned_device_name()); @@ -3382,8 +3384,8 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( std::unique_ptr<Graph>* g, Node** out, Node* orig_node) { // We use a tensor of shape {1} and value 0 to represent // dummy float tensor. We need this as a dummy workspace tensor. - // Workspace tensor has type float. - const DataType dt = DataTypeToEnum<float>::v(); + // Workspace tensor has type uint8. + const DataType dt = DataTypeToEnum<uint8>::v(); TensorProto proto; proto.set_dtype(dt); float zero[1] = {0}; @@ -3413,7 +3415,9 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( Node* orig_input0 = nullptr; TF_CHECK_OK(orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); - CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true)); } (*out)->set_assigned_device_name(orig_node->assigned_device_name()); @@ -3863,12 +3867,16 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, // node are already copied in BuildNode. We handle control edges now. for (const Edge* e : pred->in_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); } } for (const Edge* e : succ->in_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); } } @@ -3876,14 +3884,18 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, // First, we will fix outgoing control edges from 'pred' node. for (const Edge* e : pred->out_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } } // Second, we will fix outgoing control and data edges from 'succ' node. for (const Edge* e : succ->out_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 // output (at slot 0). @@ -3966,12 +3978,16 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( // edges now. for (const Edge* e : badd->in_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); } } for (const Edge* e : fltr->in_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); } } @@ -3987,7 +4003,9 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( for (const Edge* e : badd->out_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, e->dst(), e->dst_input())); @@ -3997,7 +4015,11 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( // Second, we will fix outgoing control and data edges from 'fltr' node. for (const Edge* e : fltr->out_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + // We allow duplicate edge for this case since we already add control + // edge from new_node in line 3990. Line below could be adding same + // edge to same destination again. In such case, if we do not allow + // duplicate edge, then this call will fail. + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { CHECK_NOTNULL((*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, e->dst(), e->dst_input())); @@ -4091,7 +4113,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, // already copied in BuildNode. We need to handle control edges now. for (const Edge* e : orig_node->in_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node)); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(e->src(), new_node, true)); } } @@ -4104,7 +4128,9 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, // GetTensorDataIndex provides this mapping function. for (const Edge* e : orig_node->out_edges()) { if (e->IsControlEdge()) { - CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst())); + // Allow duplicate while adding control edge as it would fail (return + // NULL) if we try to add duplicate edge. + CHECK_NOTNULL((*g)->AddControlEdge(new_node, e->dst(), true)); } else { CHECK_NOTNULL((*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), |