aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-06 09:42:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-06 09:54:53 -0700
commit10594900c5df1b84cd0336d0fb5bd0d8454bfe08 (patch)
tree7445e120556edba062cf82f3b0ade447e8cb8362 /tensorflow/core/framework
parentb71c1bb6f5edd77f92a10e99d011a00de572aa68 (diff)
Update MaxPoolV2Shape to support NCHW_VECT_C.
PiperOrigin-RevId: 167732437
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc88
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc45
2 files changed, 85 insertions, 48 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 0e3ea2ddfb..2d44480053 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -559,9 +559,6 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
}
Status AvgPoolShape(shape_inference::InferenceContext* c) {
- ShapeHandle input_shape;
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
-
string data_format_str;
TensorFormat data_format;
Status s = c->GetAttr("data_format", &data_format_str);
@@ -571,6 +568,10 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
data_format = FORMAT_NHWC;
}
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
+
TF_RETURN_IF_ERROR(
CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
@@ -627,9 +628,6 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
}
Status MaxPoolShape(shape_inference::InferenceContext* c) {
- ShapeHandle input_shape;
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
-
string data_format_str;
TensorFormat data_format;
Status s = c->GetAttr("data_format", &data_format_str);
@@ -639,6 +637,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
data_format = FORMAT_NHWC;
}
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
+
TF_RETURN_IF_ERROR(
CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
@@ -696,11 +698,21 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
}
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
+ string data_format_str;
+ TensorFormat data_format;
+ Status s = c->GetAttr("data_format", &data_format_str);
+ if (s.ok()) {
+ FormatFromString(data_format_str, &data_format);
+ } else {
+ data_format = FORMAT_NHWC;
+ }
+
+ const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4;
ShapeHandle input_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape));
- string data_format;
- Status s = c->GetAttr("data_format", &data_format);
+ TF_RETURN_IF_ERROR(
+ CheckFormatConstraintsOnShape(data_format, input_shape, "input", c));
std::vector<int32> kernel_sizes;
std::vector<int32> strides;
@@ -725,7 +737,8 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
}
kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements());
auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>();
- std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin());
+ std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(),
+ kernel_sizes.begin());
const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1);
if (strides_tensor == nullptr) {
@@ -749,35 +762,22 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
kernel_sizes.size());
}
- int32 stride_rows, stride_cols, stride_depth;
- int32 kernel_rows, kernel_cols, kernel_depth;
-
- if (s.ok() && data_format == "NCHW") {
- // Canonicalize input shape to NHWC so the shape inference code below can
- // process it.
- auto dim = [&](char dimension) {
- return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension));
- };
- input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}});
- stride_depth = strides[1];
- stride_rows = strides[2];
- stride_cols = strides[3];
- kernel_depth = kernel_sizes[1];
- kernel_rows = kernel_sizes[2];
- kernel_cols = kernel_sizes[3];
- } else {
- stride_rows = strides[1];
- stride_cols = strides[2];
- stride_depth = strides[3];
- kernel_rows = kernel_sizes[1];
- kernel_cols = kernel_sizes[2];
- kernel_depth = kernel_sizes[3];
- }
+ int32 stride_depth = GetTensorDim(strides, data_format, 'C');
+ int32 stride_rows = GetTensorDim(strides, data_format, 'H');
+ int32 stride_cols = GetTensorDim(strides, data_format, 'W');
+ int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C');
+ int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H');
+ int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W');
- DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
- DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
- DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
- DimensionHandle in_depth_dim = c->Dim(input_shape, 3);
+ constexpr int num_spatial_dims = 2;
+ DimensionHandle batch_size_dim = c->Dim(
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N'));
+ DimensionHandle in_rows_dim = c->Dim(
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H'));
+ DimensionHandle in_cols_dim = c->Dim(
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W'));
+ DimensionHandle in_depth_dim = c->Dim(
+ input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C'));
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
@@ -791,15 +791,9 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) {
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth));
- output_shape =
- c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
- if (data_format == "NCHW") {
- // Convert output shape back to expected NCHW data format.
- auto dim = [&](char dimension) {
- return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension));
- };
- output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}});
- }
+ TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim,
+ {output_rows, output_cols},
+ output_depth, &output_shape, c));
c->set_output(0, output_shape);
return Status::OK();
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 14f6c1bb45..90a48f14d4 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
@@ -704,7 +705,7 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) {
INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]");
// Invalid rank for input
- INFER_ERROR("must be at least rank 4", op, "[4,4]");
+ INFER_ERROR("Shape must be rank", op, "[4,4]");
}
TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
@@ -741,6 +742,48 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]");
}
+TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) {
+ ShapeInferenceTestOp op("MaxPoolV2");
+ Tensor ksizes_tensor, strides_tensor;
+ auto set_op = [&op, &ksizes_tensor, &strides_tensor](
+ const std::vector<int32>& strides,
+ const std::vector<int32>& ksizes, const string& padding,
+ const string& data_format) {
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2")
+ .Input("input", 0, DT_FLOAT)
+ .Input("ksize", 1, DT_INT32)
+ .Input("strides", 2, DT_INT32)
+ .Attr("padding", padding)
+ .Attr("data_format", data_format)
+ .Finalize(&op.node_def));
+ ksizes_tensor = test::AsTensor<int32>(ksizes);
+ op.input_tensors.resize(3);
+ op.input_tensors[0] = nullptr;
+ op.input_tensors[1] = &ksizes_tensor;
+ strides_tensor = test::AsTensor<int32>(strides);
+ op.input_tensors[2] = &strides_tensor;
+ };
+
+ // Most of the functionality is tested by conv-like shapes,
+ // so we check the very-specific maxpooling features here,
+ // namely depthwise kernel and striding.
+
+ // all 1 strides, depth 2 filter
+ set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
+ INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]");
+
+ // depth 3 stride, 1x1x1 filter, NCHW
+ set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
+ INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]");
+
+ // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests
+ set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C");
+ INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
+ INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
+ INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]");
+ INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]");
+}
+
TEST(CommonShapeFnsTest, Pool3DShapeTest) {
ShapeInferenceTestOp op("MaxPool3D");
auto set_op = [&op](const std::vector<int32>& strides,