aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/mkl_tfconversion_pass.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph/mkl_tfconversion_pass.cc')
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc17
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 8c3adad6f0..7c3836b308 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -81,9 +81,10 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// Is the input Op supported by Mkl-specific layout?
//
// @input op_name string of the op
+ // @input T Datatype to use for checking input op
// @return true if op is Mkl supported; false, otherwise.
- inline bool IsMklSupportedOp(const string& op_name) const {
- return mkl_layer_registry::IsMklLayer(op_name);
+ inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
+ return mkl_layer_registry::IsMklLayer(op_name, T);
}
// Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
@@ -188,6 +189,13 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
continue;
}
+ // We skip adding MklToTf on an edge between X->MklToTf or
+ // MklToTf->X, where X is any layer.
+ if (src->type_string().compare("MklToTf") == 0 ||
+ dst->type_string().compare("MklToTf") == 0) {
+ continue;
+ }
+
VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
<< src->type_string() << " and " << dst->type_string();
@@ -202,8 +210,9 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
GetNodeAttr(dst->def(), "T", &dst_datatype);
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
- if (IsMklSupportedOp(src->type_string()) &&
- !IsMklSupportedOp(dst->type_string())) {
+
+ if (IsMklSupportedOp(src->type_string(), src_datatype) &&
+ !IsMklSupportedOp(dst->type_string(), dst_datatype)) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
<< " and " << dst->name() << " for inserting conversion nodes";
candidate_edges.push_back(const_cast<Edge*>(e));