aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_layout_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_layout_pass_test.cc')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc720
1 files changed, 661 insertions, 59 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 6e72baf84e..3c4a5263af 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -39,7 +39,11 @@ limitations under the License.
namespace tensorflow {
namespace {
-static void InitGraph(const string& s, Graph* graph) {
+const char kCPUDevice[] = "/job:a/replica:0/task:0/cpu:0";
+const char kGPUDevice[] = "/job:a/replica:0/task:0/gpu:0";
+
+static void InitGraph(const string& s, Graph* graph,
+ const string& device = kCPUDevice) {
GraphDef graph_def;
auto parser = protobuf::TextFormat::Parser();
@@ -47,14 +51,18 @@ static void InitGraph(const string& s, Graph* graph) {
CHECK(parser.MergeFromString(s, &graph_def)) << s;
GraphConstructorOptions opts;
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
+
+ for (Node* node : graph->nodes()) {
+ node->set_assigned_device_name(device);
+ }
}
class MklLayoutPassTest : public ::testing::Test {
public:
MklLayoutPassTest() : graph_(OpRegistry::Global()) {}
- void InitGraph(const string& s) {
- ::tensorflow::InitGraph(s, &graph_);
+ void InitGraph(const string& s, const string& device = kCPUDevice) {
+ ::tensorflow::InitGraph(s, &graph_, device);
original_ = CanonicalGraphString(&graph_);
}
@@ -114,7 +122,8 @@ REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2").Output("o: uint8")
+ .Output("o1: uint8").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optiimization
@@ -162,8 +171,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
- "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
+ "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;"
+ "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;M->E:3;"
+ "N->E:4;Y->Z:1");
}
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
@@ -194,8 +204,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
- "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
- "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
+ "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;"
+ "A:control->DMT/_0:control;B->E:1;D->E:2;DMT/_0->E:5;E->Z;"
+ "M:1->E:3;N:1->E:4;Y->Z:1");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
@@ -226,8 +237,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
- "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
- "E->Z;Y->Z:1");
+ "A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;"
+ "DMT/_2->E:5;E->Z;Y->Z:1");
}
// Graph contains only _MklConv2D, no AddBias.
@@ -330,9 +342,6 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}
-// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
-// of BiasAddGrad into BackpropBias
-#if 0
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
// rewrite tests
@@ -361,18 +370,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
- "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
- "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
- "M->D:3;N->D:4;O->D:5");
+ "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
+ "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
+ "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
+ "O->D:5");
}
-#endif
-// No _MklConv2D in context, but Conv2D in context.
-// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
-// for BiasAddGrad should happen.
+// No _MklConv2DWithBias in context, but _MklConv2D in context.
+// No rewrite for BiasAddGrad should happen.
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
@@ -507,8 +515,10 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
- "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
+ "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);"
+ "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
+ "DMT/_1->C:3");
}
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
@@ -535,7 +545,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;"
+ "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;B->C:1;C->D:1;C->E;"
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
@@ -558,6 +570,50 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
"A->C;B->C:1;B->D;C->D:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Conv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropFilter);"
+ "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+ "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+ "DMT/_1->D:4;DMT/_2->D:5");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradInput_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Conv2DBackpropInput'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['B', 'A', 'C']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Input);D(_MklConv2DBackpropInput);"
+ "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Mul)|"
+ "A->D:1;A->E;B->D;B:control->DMT/_0:control;"
+ "B:control->DMT/_1:control;B:control->DMT/_2:control;C->D:2;"
+ "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
// Concat Op test: Concat with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
InitGraph(
@@ -572,13 +628,14 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
"node { name: 'D' op: 'Concat'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'N' value { i: 2 } }"
- " input: ['A', 'B']}"
+ " input: ['A', 'B:0', 'B:1']}"
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
- "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;A:control->DMT/_2:control;B->D:1;"
+ "B:1->D:2;C->E;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Concat with 2 Mkl layers feeding it
@@ -616,9 +673,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
+ "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+ "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
+ "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;"
+ "G:control->DMT/_4:control;H->I:1");
}
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
@@ -651,12 +711,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
- "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+ "H(_MklConcat);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
- "G->H;H->I:1");
+ "G->H;G:control->DMT/_2:control;G:control->DMT/_3:control;H->I:1");
}
-#if 0
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
InitGraph(
@@ -676,11 +736,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
- "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
- "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
- "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+ "A(Const);B(InputList);C(Input);D(_MklConcatV2);DMT/_0(Const);"
+ "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;"
+ "B:control->DMT/_0:control;B:control->DMT/_1:control;"
+ "B:control->DMT/_2:control;C->E;D->E:1;DMT/_0->D:3;"
+ "DMT/_1->D:4;DMT/_2->D:5");
}
-#endif
// ConcatV2 with 2 Mkl layers feeding it
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
@@ -718,9 +779,12 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
- "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
+ "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
+ "C:control->DMT/_0:control;C:control->DMT/_1:control;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
- "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
+ "DMT/_4->H:5;E->H;E:1->H:3;E:control->DMT/_4:control;F->H:1;"
+ "F:1->H:4;G->H:2;H->I:1");
}
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
@@ -754,11 +818,175 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
- "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
- "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
+ "H(_MklConcatV2);I(Mul)|A->E;A->I;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+ "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;"
+ "E:control->DMT/_2:control;E:control->DMT/_3:control;F->H:1;"
"G->H:2;H->I:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Relu_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Relu'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklRelu);C(Mul);DMT/_0(Const)|A->B;A->C;"
+ "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'ReluGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+ "DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ReluReluGrad_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Relu'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'ReluGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklRelu);C(_MklReluGrad);D(Mul);DMT/_0(Const);"
+ "DMT/_1(Const)|A->B;A->C;A->D;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;"
+ "DMT/_1->C:2");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'AvgPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklAvgPool);C(Mul);DMT/_0(Const)|A->B;A->C;"
+ "A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolGrad_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Int32Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'AvgPoolGrad' "
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'C'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Int32Input);B(Input);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+ "DMT/_1(Const)|A->C;A:control->DMT/_0:control;"
+ "A:control->DMT/_1:control;B->C:1;B->D;C->D:1;DMT/_0->C:2;"
+ "DMT/_1->C:3");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPoolAvgPoolGrad_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'I' op: 'Int32Input'}"
+ "node { name: 'B' op: 'AvgPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'AvgPoolGrad' "
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
+ " input: ['I', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(_MklAvgPool);C(_MklAvgPoolGrad);D(Mul);DMT/_0(Const);"
+ "DMT/_1(Const);I(Int32Input)|A->B;A->D;A:control->DMT/_0:control;"
+ "B->C:1;B:1->C:3;C->D:1;DMT/_0->B:1;DMT/_1->C:2;I->C;"
+ "I:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormGrad_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'FusedBatchNormGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'epsilon' value { f: 0.0001 } }"
+ " attr { key: 'is_training' value { b: true } }"
+ " input: ['A', 'B', 'C', 'D', 'E'] }"
+ "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'F'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
+ "F(_MklFusedBatchNormGrad);G(Mul)|A->F;A->G;"
+ "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+ "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
+ "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
+ "E->F:4;F->G:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_Positive) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'FusedBatchNorm'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'epsilon' value { f: 0.0001 } }"
+ " attr { key: 'is_training' value { b: true } }"
+ " input: ['A', 'B', 'C', 'D', 'E'] }"
+ "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'F'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+ "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Input);"
+ "F(_MklFusedBatchNorm);G(Mul)|A->F;A->G;"
+ "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+ "A:control->DMT/_4:control;B->F:1;C->F:2;D->F:3;"
+ "DMT/_0->F:5;DMT/_1->F:6;DMT/_2->F:7;DMT/_3->F:8;DMT/_4->F:9;"
+ "E->F:4;F->G:1");
+}
+
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node for workspace edges
/////////////////////////////////////////////////////////////////////
@@ -802,13 +1030,13 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
- EXPECT_EQ(
- DoMklLayoutOptimizationPass(),
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
- "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
- "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
- "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
- "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
+ "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);"
+ "I(Mul)|A->B;A:control->DMT/_0:control;B->C;B->E;B->G:2;B:1->G:3;"
+ "B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;B:control->DMT/_1:control;C->E:1;"
+ "C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;DMT/_2->G:5;"
+ "E->G;E:1->G:4;E:control->DMT/_2:control;F->G:1;G->I:1;H->I");
}
/* Test LRN->LRNGrad replacement by workspace nodes. */
@@ -838,8 +1066,9 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
- "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
- "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
+ "A->B;A:control->DMT/_0:control;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;"
+ "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
+ "D->E:1;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
@@ -858,7 +1087,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
- "A->B;A->C;B->C:1;DMT/_0->B:1");
+ "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
/* Test LRN->LRNGrad replacement when only one of them is present. */
@@ -880,8 +1109,10 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
- "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
- "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+ "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+ "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+ "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
}
/* Test LRN->LRNGrad negative case, where single LRN feeds
@@ -920,9 +1151,13 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
- "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
- "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
- "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
+ "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;"
+ "A:control->DMT/_0:control;B->E:2;"
+ "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
+ "C:control->DMT/_1:control;C:control->DMT/_2:control;"
+ "C:control->DMT/_3:control;C:control->DMT/_4:control;"
+ "C:control->DMT/_5:control;C:control->DMT/_6:control;"
+ "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
}
@@ -951,8 +1186,9 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
- "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
- "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
+ "A->B;A:control->DMT/_0:control;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;"
+ "C->E;C->F;C:control->DMT/_1:control;C:control->DMT/_2:control;"
+ "D->E:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
// Test MaxPool>MaxPoolGrad replacement when only one of them is present.
@@ -972,7 +1208,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
- "A->B;A->C;B->C:1;DMT/_0->B:1");
+ "A->B;A->C;A:control->DMT/_0:control;B->C:1;DMT/_0->B:1");
}
// Test MaxPoolGrad replacement when only one of them is present.
@@ -995,8 +1231,374 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
- "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
- "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+ "A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "A:control->DMT/_2:control;A:control->DMT/_3:control;"
+ "A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
+ "DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+}
+
+// Test MaxPool handling for batch-wise pooling (NCHW)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative3) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for batch-wise pooling (NCHW)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative4) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for depth-wise pooling (NHWC)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative5) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:2, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for depth-wise pooling (NCHW)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative6) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:2, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for batch-wise pooling (NHWC)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative7) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 2, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for batch-wise pooling (NHWC)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative8) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 2, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for depth-wise pooling (NHWC)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative9) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:2} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Test MaxPool handling for depth-wise pooling (NHWC)
+// No rewrite should take place in such case
+TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative10) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:2} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+/////////////////////////////////////////////////////////////////////
+
+// Single Conv2D Op on GPU device
+// No rewrite should happen
+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Conv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B']}"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['B', 'C'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Conv2D);D(Mul)|A->C;B->C:1;B->D;C->D:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'O' op: '_MklInput'}"
+ "node { name: 'D' op: '_MklConv2DWithBias'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+ "node { name: 'E' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'A']}"
+ "node { name: 'F' op: 'BiasAddGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['E'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+ "E(Sub);F(BiasAddGrad);M(_MklInput);N(_MklInput);"
+ "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;E->F;"
+ "M->D:3;N->D:4;O->D:5");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Conv2DBackpropFilter'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'C']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Input);D(Conv2DBackpropFilter);E(Mul)|"
+ "A->D;A->E;B->D:1;C->D:2;D->E:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_Relu_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Relu'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Relu);C(Mul)|A->B;A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ReluGrad_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'ReluGrad'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }"
+ "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'C'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(ReluGrad);D(Mul)|A->C;A->D;B->C:1;C->D:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_MaxPool_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'MaxPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(MaxPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_AvgPool_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'AvgPool'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NHWC' } }"
+ " attr { key: 'ksize' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'VALID' } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " input: ['A'] }"
+ "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'B'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(AvgPool);C(Mul)|A->B;A->C;B->C:1");
+}
+
+// Concat Op test: Concat with no Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'InputList'"
+ " attr { key: 'N' value { i: 2 } }}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Concat'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['A', 'B:0', 'B:1']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Const);B(InputList);C(Input);D(Concat);E(Mul)|A->D;"
+ "B->D:1;B:1->D:2;C->E;D->E:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Const' "
+ " attr { key: 'dtype' value { type: DT_INT32 } }"
+ " attr { key: 'value' value { "
+ " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+ " int_val: 0 } } } }"
+ "node { name: 'B' op: 'InputList'"
+ " attr { key: 'N' value { i: 2 } }}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'ConcatV2'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Tidx' value { type: DT_INT32 } }"
+ " attr { key: 'N' value { i: 2 } }"
+ " input: ['B:0', 'B:1', 'A']}"
+ "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['C', 'D'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Const);B(InputList);C(Input);D(ConcatV2);E(Mul)|"
+ "A->D:2;B->D;B:1->D:1;C->E;D->E:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'C' op: 'Input'}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'FusedBatchNorm'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'epsilon' value { f: 0.0001 } }"
+ " attr { key: 'is_training' value { b: true } }"
+ " input: ['A', 'B', 'C', 'D', 'E'] }"
+ "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'F'] }", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(Input);D(Input);E(Input);"
+ "F(FusedBatchNorm);G(Mul)|A->F;A->G;B->F:1;C->F:2;D->F:3;"
+ "E->F:4;F->G:1");
+}
+
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
+ CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Input'}"
+ "node { name: 'M' op: '_MklInput'}"
+ "node { name: 'N' op: '_MklInput'}"
+ "node { name: 'C' op: '_MklConv2D'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+ " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+ " attr { key: 'padding' value { s: 'SAME' } }"
+ " input: ['A', 'B', 'M', 'N']}"
+ "node { name: 'D' op: 'Input'}"
+ "node { name: 'E' op: 'BiasAdd'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['C', 'D'] }"
+ "node { name: 'Y' op: 'Input'}"
+ "node { name: 'Z' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['E', 'Y']}", kGPUDevice);
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);"
+ "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->C;"
+ "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
/////////////////////////////////////////////////////////////////////