diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 23:14:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 23:14:39 -0700 |
commit | 08a6cfed1cf0cccc8ff35448266f44fbc55be0bc (patch) | |
tree | 73f61074984cd9dcf05e5d65b454a6ce08484f4a /tensorflow/core/graph | |
parent | d3f14ef70cdf113f9d330c1f7c638003429a1dc4 (diff) | |
parent | d1ab8b71c2115caacfec19d849ddabf7f1f4287b (diff) |
Merge pull request #22076 from Intel-tensorflow:feature/daoxin/slice
PiperOrigin-RevId: 214726180
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass.cc | 19 | ||||
-rw-r--r-- | tensorflow/core/graph/mkl_layout_pass_test.cc | 20 |
2 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 37b88f1728..06d3fefef1 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2450,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.tanh = "Tanh"; csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; + csinfo_.slice = "Slice"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; // Element-wise ops. Ensure you also add any new ops to IsOpElementWise @@ -2557,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite}); + rinfo_.push_back({csinfo_.slice, + mkl_op_registry::GetMklOpName(csinfo_.slice), + CopyAttrsSlice, AlwaysRewrite}); rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsDataType, AlwaysRewrite}); @@ -2676,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string tanh; string tanh_grad; string reshape; + string slice; string softmax; string split; string squared_difference; @@ -3134,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, @@ -3739,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, nb->Attr("Tshape", Tshape); } +void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + DataType Index; + + // Get all attributes from old node. + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index)); + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("Index", Index); +} + void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb) { DataType T; diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index f42a4ee98b..77640e287c 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Int32Input'}" + "node { name: 'D' op: 'Slice'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Index' value { type: DT_INT32 } }" + " input: ['A', 'B', 'C'] }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Int32Input);" + "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/" + "_2(Const);E(Zeta)|A->D;A->E;" + "A:control->DMT/_0:control;A:control->DMT/" + "_1:control;A:control->DMT/_2:control;" + "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); +} + ///////////////////////////////////////////////////////////////////// // Post-rewrite fixup pass test |