aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 23:14:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 23:14:39 -0700
commit08a6cfed1cf0cccc8ff35448266f44fbc55be0bc (patch)
tree73f61074984cd9dcf05e5d65b454a6ce08484f4a /tensorflow/core/graph
parentd3f14ef70cdf113f9d330c1f7c638003429a1dc4 (diff)
parentd1ab8b71c2115caacfec19d849ddabf7f1f4287b (diff)
Merge pull request #22076 from Intel-tensorflow:feature/daoxin/slice
PiperOrigin-RevId: 214726180
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc19
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc20
2 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 37b88f1728..06d3fefef1 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2450,6 +2450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2557,6 +2558,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2676,6 +2680,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3134,6 +3139,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3739,6 +3745,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index f42a4ee98b..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test