diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass.cc | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 8c3adad6f0..7c3836b308 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -81,9 +81,10 @@ class MklToTfConversionPass : public GraphOptimizationPass { // Is the input Op supported by Mkl-specific layout? // // @input op_name string of the op + // @input T Datatype to use for checking input op // @return true if op is Mkl supported; false, otherwise. - inline bool IsMklSupportedOp(const string& op_name) const { - return mkl_layer_registry::IsMklLayer(op_name); + inline bool IsMklSupportedOp(const string& op_name, DataType T) const { + return mkl_layer_registry::IsMklLayer(op_name, T); } // Insert layout conversion node on the edge pointed by 'e' from graph 'g'. @@ -188,6 +189,13 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { continue; } + // We skip adding MklToTf on an edge between X->MklToTf or + // MklToTf->X, where X is any layer. + if (src->type_string().compare("MklToTf") == 0 || + dst->type_string().compare("MklToTf") == 0) { + continue; + } + VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: " << src->type_string() << " and " << dst->type_string(); @@ -202,8 +210,9 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { GetNodeAttr(dst->def(), "T", &dst_datatype); // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. - if (IsMklSupportedOp(src->type_string()) && - !IsMklSupportedOp(dst->type_string())) { + + if (IsMklSupportedOp(src->type_string(), src_datatype) && + !IsMklSupportedOp(dst->type_string(), dst_datatype)) { VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name() << " and " << dst->name() << " for inserting conversion nodes"; candidate_edges.push_back(const_cast<Edge*>(e)); |