aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc17
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc6
-rw-r--r--tensorflow/core/graph/testlib.cc10
-rw-r--r--tensorflow/core/graph/testlib.h4
4 files changed, 24 insertions, 13 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 68c3136019..7d3be15299 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -42,7 +42,7 @@ limitations under the License.
namespace tensorflow {
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
// This pass implements rewriting of graph to support following scenarios:
// (A) Merging nodes in the graph
@@ -2211,7 +2211,7 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
return Status::OK();
}
-#else // INTEL_MKL_DNN
+#else // INTEL_MKL_ML
// This pass implements rewriting of graph to support following scenarios:
// (A) Merging nodes in the graph
@@ -2452,9 +2452,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAddN, AddNRewrite});
- /* rinfo_.push_back({csinfo_.add,
- mkl_op_registry::GetMklOpName(csinfo_.add),
- CopyAttrsDataType, AlwaysRewrite}); */
+ rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
+ CopyAttrsDataType, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
CopyAttrsPooling, AlwaysRewrite});
@@ -2502,14 +2501,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
- /*
+
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
rinfo_.push_back({csinfo_.mul,
mkl_op_registry::GetMklOpName(csinfo_.mul),
CopyAttrsDataType, AlwaysRewrite});
- */
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsDataType, AlwaysRewrite});
rinfo_.push_back({csinfo_.relu_grad,
@@ -2529,14 +2527,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
- /*
+
rinfo_.push_back({csinfo_.squared_difference,
mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
CopyAttrsDataType, AlwaysRewrite});
rinfo_.push_back({csinfo_.sub,
mkl_op_registry::GetMklOpName(csinfo_.sub),
CopyAttrsDataType, AlwaysRewrite});
- */
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
@@ -4317,7 +4314,7 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
return Status::OK();
}
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
} // namespace tensorflow
#endif
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 320d5a48c7..5e2a465e22 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -38,7 +38,7 @@ limitations under the License.
namespace tensorflow {
-#ifndef INTEL_MKL_DNN
+#ifdef INTEL_MKL_ML
namespace {
@@ -1899,7 +1899,7 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace
-#else // INTEL_MKL_DNN
+#else // INTEL_MKL_ML
namespace {
@@ -3532,7 +3532,7 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace
-#endif // INTEL_MKL_DNN
+#endif // INTEL_MKL_ML
} // namespace tensorflow
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index d5b026eae3..0d88d1ff72 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -273,6 +273,16 @@ Node* Reverse(Graph* g, Node* tensor, Node* axis) {
return Binary(g, "ReverseV2", tensor, axis);
}
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Roll", g->op_registry())
+ .Input(input)
+ .Input(shift)
+ .Input(axis)
+ .Finalize(g, &ret));
+ return ret;
+}
+
Node* Error(Graph* g, Node* input, const string& errmsg) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "Error")
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 06597778bb..eb9038d619 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -117,6 +117,10 @@ Node* RandomGamma(Graph* g, Node* shape, Node* alpha);
// Output dtype determined by lam.
Node* RandomPoisson(Graph* g, Node* shape, Node* lam);
+// Rolls tensor by an offset of <shift> along the corresponding
+// <axis> dimensions.
+Node* Roll(Graph* g, Node* input, Node* shift, Node* axis);
+
// Generates random parameters from the truncated standard normal distribution
// of the nput shape
Node* TruncatedNormal(Graph* g, Node* input, DataType dtype);