diff options
author | 2018-08-16 13:04:52 -0700 | |
---|---|---|
committer | 2018-08-16 13:05:05 -0700 | |
commit | 9c50882415cb87a7eb81048d42401c64bf0617ef (patch) | |
tree | c550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core/graph | |
parent | 19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff) | |
parent | 62191da0819b25906c1b2ed96159cfe36ba00383 (diff) |
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 39 |
1 files changed, 25 insertions, 14 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 5683944e46..833592caab 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2418,6 +2418,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter"; csinfo_.conv2d_grad_filter_with_bias = "__MklDummyConv2DBackpropFilterWithBias"; + csinfo_.conv3d = "Conv3D"; + csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2"; + csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2"; csinfo_.fused_batch_norm = "FusedBatchNorm"; csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad"; csinfo_.identity = "Identity"; @@ -2468,18 +2471,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsConcatV2, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, - csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D, + csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.conv2d_grad_input, mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), - CopyAttrsConv2D, AlwaysRewrite}); + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d, + mkl_op_registry::GetMklOpName(csinfo_.conv3d), + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d_grad_filter, + mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter), + CopyAttrsConv, AlwaysRewrite}); + rinfo_.push_back({csinfo_.conv3d_grad_input, + mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input), + CopyAttrsConv, AlwaysRewrite}); rinfo_.push_back({csinfo_.fused_batch_norm, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), CopyAttrsFusedBatchNorm, AlwaysRewrite}); @@ -2614,6 +2626,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string conv2d_grad_input; string conv2d_grad_filter; string conv2d_grad_filter_with_bias; + string conv3d; + string conv3d_grad_input; + string conv3d_grad_filter; string fused_batch_norm; string fused_batch_norm_grad; string identity; @@ -3086,7 +3101,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb); - static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); @@ -3571,14 +3586,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( // Op-specific functions to copy attributes from old node to new node ////////////////////////////////////////////////////////////////////////// -void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, - NodeBuilder* nb) { +void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, + NodeBuilder* nb) { DataType T; string data_format; string padding; std::vector<int32> strides; std::vector<int32> dilations; - bool use_cudnn_on_gpu; // Get all attributes from old node. TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); @@ -3586,8 +3600,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format)); - TF_CHECK_OK( - GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu)); // Add attributes to new node. nb->Attr("T", T); @@ -3595,7 +3607,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node, nb->Attr("dilations", dilations); nb->Attr("padding", padding); nb->Attr("data_format", data_format); - nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu); } void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node, @@ -3896,7 +3907,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd // Copy attributes from Conv2D to Conv2DWithBias. - CopyAttrsConv2D(const_cast<const Node*>(pred), &nb); + CopyAttrsConv(const_cast<const Node*>(pred), &nb); // Copy the device assigned to old node to new node. nb.Device(succ->def().device()); @@ -4007,7 +4018,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( } // Copy attributes from Conv2DBackpropFilter. - CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb); + CopyAttrsConv(const_cast<const Node*>(fltr), &nb); // Copy the device assigned to old node to new node. nb.Device(fltr->def().device()); |