diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass_test.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 36 |
1 files changed, 21 insertions, 15 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index 0a63cf6ddb..7d9237f845 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -17,7 +17,10 @@ limitations under the License. #include "tensorflow/core/graph/mkl_tfconversion_pass.h" +#include <algorithm> +#include <string> #include <vector> + #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" @@ -146,31 +149,34 @@ TEST_F(MklToTfConversionPass, Positive) { "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3"); } -// MklConv2D followed by Non-Mkl layer, and MklConv2D uses half type -// C=MklConv2D(A,M,B,N); E=Sub(C,D) -// MklToTf node should be inserted. -TEST_F(MklToTfConversionPass, Positive_Type) { +// MklConv2D followed by MklToTf op followed by Non-Mkl layer. +// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) +// MklToTf node should not be inserted again. +TEST_F(MklToTfConversionPass, Negative_DoubleInsert) { InitGraph( - "node { name: 'A' op: 'HalfInput'}" + "node { name: 'A' op: 'Input'}" "node { name: 'M' op: 'MklInput'}" - "node { name: 'B' op: 'HalfInput'}" + "node { name: 'B' op: 'Input'}" "node { name: 'N' op: 'MklInput'}" "node { name: 'C' op: 'MklConv2D'" - " attr { key: 'T' value { type: DT_HALF } }" + " attr { key: 'T' value { type: DT_FLOAT } }" " attr { key: 'data_format' value { s: 'NCHW' } }" " attr { key: 'use_cudnn_on_gpu' value { b: false } }" " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" " attr { key: 'padding' value { s: 'SAME' } }" " input: ['A', 'M', 'B', 'N']}" - "node { name: 'D' op: 'HalfInput'}" - "node { name: 'E' op: 'Sub'" - " attr {key: 'T' value { type: DT_HALF } }" - " input: ['C', 'D']}"); + "node { name: 'D' op: 'MklToTf'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['C:0', 'C:1']}" + "node { name: 'E' op: 'Input'}" + "node { name: 'F' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'E']}"); EXPECT_EQ(DoRunMklToTfConversionPass(), - "A(HalfInput);B(HalfInput);C(MklConv2D);D(HalfInput);" - "E(Sub);M(MklInput);Mkl2Tf/_0(MklToTf);N(MklInput)|" - "A->C;B->C:2;C->Mkl2Tf/_0;C:1->Mkl2Tf/_0:1;D->E:1;" - "M->C:1;Mkl2Tf/_0->E;N->C:3"); + "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);" + "F(Sub);M(MklInput);N(MklInput)|" + "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3"); } // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y); |