diff options
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r-- | tensorflow/core/graph/mkl_tfconversion_pass.cc | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 55c280719c..590b3d030f 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass { Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*); }; -// We register MklToTf insertion for phase 1 in post-partition grouping. -// We register this pass after partitioning so that we get a complete -// picture of inputs and outputs of the nodes in the graphs. +// We register MklToTf insertion for phase 2 in post-partition grouping +// because we register MklLayoutRewritePass for phase 1 in post-partition +// grouping. We register this pass after partitioning so that we get a +// complete picture of inputs and outputs of the nodes in the graphs. const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = OptimizationPassRegistry::POST_PARTITIONING; -REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass); +REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); Status MklToTfConversionPass::InsertConversionNodeOnEdge( std::unique_ptr<Graph>* g, Edge* e) { @@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge( string data_format; TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype)); - TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype)); - if (src_datatype != dst_datatype) { - string err_msg = "T attribute of " + src->name() + " and " + dst->name() + - " do not match. Will not insert" + + bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) == + Status::OK(); + // We compare source and destination datatypes only when both are found. + if (dst_dtype_found && (src_datatype != dst_datatype)) { + string err_msg = "T attribute of " + src->name() + " and " + + dst->name() + " do not match. Will not insert" + " MklToTf node in such case."; return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); } @@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { << src->type_string() << " and " << dst->type_string(); // Let's get source and destination data type. - DataType src_datatype = DT_INVALID; - if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) { - continue; - } // We cannot check datatype on destination node because destination node // may not be Mkl node. - DataType dst_datatype = DT_INVALID; - GetNodeAttr(dst->def(), "T", &dst_datatype); + DataType src_datatype; + DataType dst_datatype; + bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) == + Status::OK() && + IsMklSupportedOp(src->type_string(), src_datatype)); + bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) == + Status::OK() && + IsMklSupportedOp(dst->type_string(), dst_datatype)); // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. - if (IsMklSupportedOp(src->type_string(), src_datatype) && - !IsMklSupportedOp(dst->type_string(), dst_datatype)) { + if (src_is_mkl_op && !dst_is_mkl_op) { VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name() << " and " << dst->name() << " for inserting conversion nodes"; candidate_edges.push_back(const_cast<Edge*>(e)); |