aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-13 15:52:43 -0700
committerGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-13 15:52:43 -0700
commit135ac89cae38464a9c6ea21af244e4a1bda255ed (patch)
tree8542b2df9e68d511718ad14e393e55e4065fca8c /tensorflow/core/graph
parent087190246b95dc4c188f630ca90880a12e39b557 (diff)
enable pooling3D op
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc28
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc12
2 files changed, 35 insertions, 5 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 5683944e46..30e48d3860 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -334,6 +334,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
+
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
@@ -546,14 +547,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -2408,6 +2409,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.avg_pool3d = "AvgPool3D";
+ csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -2426,6 +2429,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.max_pool3d = "MaxPool3D";
+ csinfo_.max_pool3d_grad = "MaxPool3DGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
@@ -2460,6 +2465,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite});
@@ -2501,7 +2512,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, MaxpoolGradRewrite});
-
+ rinfo_.push_back({csinfo_.max_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
+ rinfo_.push_back({csinfo_.max_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
@@ -2538,6 +2554,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
+ wsinfo_.push_back
+ ({csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
@@ -2605,6 +2623,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string add;
string avg_pool;
string avg_pool_grad;
+ string avg_pool3d;
+ string avg_pool3d_grad;
string bias_add;
string bias_add_grad;
string concat;
@@ -2622,6 +2642,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string max_pool3d;
+ string max_pool3d_grad;
string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index aa39af637f..b67a321fc1 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -175,7 +175,11 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
- if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
+ // using data_format. This code might be redundant.
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -254,9 +258,13 @@ Status MklToTfConversionPass::InsertInputConversionNode(
}
}
+ // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
+ // seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
- Status::OK()) {
+ Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}