aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 19:00:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 19:00:50 -0700
commitc4156ee08bed83ce54ab14a606af498dc8ebdbe6 (patch)
treea41fbe5865114bb3a1650a5173dd0e244f0896b9 /tensorflow/core/graph
parentfa607e7e9224b4d88ead0a81fc65c7884d25950a (diff)
parent0fb7fcaa22c7d4167b4586c8a44f08b8830c0471 (diff)
Merge pull request #21586 from Intel-tensorflow:pooling3d
PiperOrigin-RevId: 210474549
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc28
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc12
2 files changed, 35 insertions, 5 deletions
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 833592caab..7e501c1717 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -334,6 +334,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
+
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
@@ -546,14 +547,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -2408,6 +2409,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.avg_pool3d = "AvgPool3D";
+ csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -2429,6 +2432,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.max_pool3d = "MaxPool3D";
+ csinfo_.max_pool3d_grad = "MaxPool3DGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
@@ -2463,6 +2468,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite});
@@ -2513,7 +2524,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, MaxpoolGradRewrite});
-
+ rinfo_.push_back({csinfo_.max_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
+ rinfo_.push_back({csinfo_.max_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
@@ -2550,6 +2566,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
+ wsinfo_.push_back
+ ({csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
@@ -2617,6 +2635,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string add;
string avg_pool;
string avg_pool_grad;
+ string avg_pool3d;
+ string avg_pool3d_grad;
string bias_add;
string bias_add_grad;
string concat;
@@ -2637,6 +2657,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string max_pool3d;
+ string max_pool3d_grad;
string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index aa39af637f..b67a321fc1 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -175,7 +175,11 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
- if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
+ // using data_format. This code might be redundant.
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -254,9 +258,13 @@ Status MklToTfConversionPass::InsertInputConversionNode(
}
}
+ // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
+ // seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
- Status::OK()) {
+ Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}