aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_tfconversion_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc36
1 files changed, 20 insertions, 16 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 55c280719c..590b3d030f 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -98,12 +98,13 @@ class MklToTfConversionPass : public GraphOptimizationPass {
Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);
};
-// We register MklToTf insertion for phase 1 in post-partition grouping.
-// We register this pass after partitioning so that we get a complete
-// picture of inputs and outputs of the nodes in the graphs.
+// We register MklToTf insertion for phase 2 in post-partition grouping
+// because we register MklLayoutRewritePass for phase 1 in post-partition
+// grouping. We register this pass after partitioning so that we get a
+// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
-REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 1, MklToTfConversionPass);
+REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
@@ -121,10 +122,12 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
string data_format;
TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
- TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
- if (src_datatype != dst_datatype) {
- string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
- " do not match. Will not insert" +
+ bool dst_dtype_found = GetNodeAttr(dst->def(), "T", &dst_datatype) ==
+ Status::OK();
+ // We compare source and destination datatypes only when both are found.
+ if (dst_dtype_found && (src_datatype != dst_datatype)) {
+ string err_msg = "T attribute of " + src->name() + " and " +
+ dst->name() + " do not match. Will not insert" +
" MklToTf node in such case.";
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}
@@ -202,18 +205,19 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
<< src->type_string() << " and " << dst->type_string();
// Let's get source and destination data type.
- DataType src_datatype = DT_INVALID;
- if (GetNodeAttr(src->def(), "T", &src_datatype) != Status::OK()) {
- continue;
- }
// We cannot check datatype on destination node because destination node
// may not be Mkl node.
- DataType dst_datatype = DT_INVALID;
- GetNodeAttr(dst->def(), "T", &dst_datatype);
+ DataType src_datatype;
+ DataType dst_datatype;
+ bool src_is_mkl_op = (GetNodeAttr(src->def(), "T", &src_datatype) ==
+ Status::OK() &&
+ IsMklSupportedOp(src->type_string(), src_datatype));
+ bool dst_is_mkl_op = (GetNodeAttr(dst->def(), "T", &dst_datatype) ==
+ Status::OK() &&
+ IsMklSupportedOp(dst->type_string(), dst_datatype));
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
- if (IsMklSupportedOp(src->type_string(), src_datatype) &&
- !IsMklSupportedOp(dst->type_string(), dst_datatype)) {
+ if (src_is_mkl_op && !dst_is_mkl_op) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
<< " and " << dst->name() << " for inserting conversion nodes";
candidate_edges.push_back(const_cast<Edge*>(e));