aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass_test.cc')
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc36
1 files changed, 21 insertions, 15 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index 0a63cf6ddb..7d9237f845 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -17,7 +17,10 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
+#include <algorithm>
+#include <string>
#include <vector>
+
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
@@ -146,31 +149,34 @@ TEST_F(MklToTfConversionPass, Positive) {
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
}
-// MklConv2D followed by Non-Mkl layer, and MklConv2D uses half type
-// C=MklConv2D(A,M,B,N); E=Sub(C,D)
-// MklToTf node should be inserted.
-TEST_F(MklToTfConversionPass, Positive_Type) {
+// MklConv2D followed by MklToTf op followed by Non-Mkl layer.
+// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E)
+// MklToTf node should not be inserted again.
+TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
InitGraph(
- "node { name: 'A' op: 'HalfInput'}"
+ "node { name: 'A' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
- "node { name: 'B' op: 'HalfInput'}"
+ "node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
- " attr { key: 'T' value { type: DT_HALF } }"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
- "node { name: 'D' op: 'HalfInput'}"
- "node { name: 'E' op: 'Sub'"
- " attr {key: 'T' value { type: DT_HALF } }"
- " input: ['C', 'D']}");
+ "node { name: 'D' op: 'MklToTf'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'data_format' value { s: 'NCHW' } }"
+ " input: ['C:0', 'C:1']}"
+ "node { name: 'E' op: 'Input'}"
+ "node { name: 'F' op: 'Sub'"
+ " attr {key: 'T' value { type: DT_FLOAT } }"
+ " input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
- "A(HalfInput);B(HalfInput);C(MklConv2D);D(HalfInput);"
- "E(Sub);M(MklInput);Mkl2Tf/_0(MklToTf);N(MklInput)|"
- "A->C;B->C:2;C->Mkl2Tf/_0;C:1->Mkl2Tf/_0:1;D->E:1;"
- "M->C:1;Mkl2Tf/_0->E;N->C:3");
+ "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
+ "F(Sub);M(MklInput);N(MklInput)|"
+ "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);