-rw-r--r--  tensorflow/core/kernels/fractional_max_pool_op.cc |  2
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                     | 48
-rw-r--r--  tensorflow/core/ops/nn_ops_test.cc                | 60
-rw-r--r--  tensorflow/python/ops/nn_ops.py                   | 31
4 files changed, 114 insertions(+), 27 deletions(-)
diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc
index 482491b504..a422433ecf 100644
--- a/tensorflow/core/kernels/fractional_max_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_max_pool_op.cc
@@ -72,6 +72,8 @@ class FractionalMaxPoolOp : public OpKernel {
}
// Output size.
for (int i = 0; i < tensor_in_and_out_dims; ++i) {
+ // This must match the same logic in the shape function in
+ // core/ops/nn_ops.cc.
output_size_.push_back(
static_cast<int>(floor(input_size_[i] / pooling_ratio_[i])));
DCHECK_GT(output_size_[i], 0);
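
For reference, a small standalone Python sketch (not part of this patch) of the size rule that both the kernel and the new shape function below rely on: each output dimension is floor(input_dim / pooling_ratio[i]).

import math

def fractional_pool_output_size(input_size, pooling_ratio):
    # Mirrors FractionalMaxPoolOp and FractionalPoolShapeFn: each output
    # dimension is the floor of input_dim / pooling_ratio[i].
    return [int(math.floor(d / r)) for d, r in zip(input_size, pooling_ratio)]

# e.g. a [10, 20, 30, 40] input with pooling_ratio [1.0, 2.0, 1.5, 1.0]
# yields a [10, 10, 20, 40] output.
print(fractional_pool_output_size([10, 20, 30, 40], [1.0, 2.0, 1.5, 1.0]))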
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 9cf1ff766a..9fea2213a5 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -45,6 +45,39 @@ Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx,
return Status::OK();
}
+Status FractionalPoolShapeFn(InferenceContext* c) {
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+
+ std::vector<float> pooling_ratio;
+ TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
+ if (pooling_ratio.size() != 4) {
+ return errors::InvalidArgument(
+ "pooling_ratio field must specify 4 dimensions");
+ }
+ std::vector<DimensionHandle> output_dims;
+ for (int i = 0; i < 4; ++i) {
+ DimensionHandle d = c->Dim(input, i);
+ if (c->ValueKnown(d)) {
+ // This must match the same logic in the kernel function in
+ // core/kernels/fractional_max_pool_op.cc.
+ auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i]));
+ if (val < 0) {
+ return errors::InvalidArgument("Size computed for dim ", i,
+ " is negative: ", val);
+ }
+ output_dims.push_back(c->MakeDim(val));
+ } else {
+ output_dims.push_back(c->UnknownDim());
+ }
+ }
+
+ c->set_output(0, c->MakeShape(output_dims));
+ c->set_output(1, c->Vector(output_dims[1]));
+ c->set_output(2, c->Vector(output_dims[2]));
+ return Status::OK();
+}
+
} // namespace
// --------------------------------------------------------------------------
@@ -1572,6 +1605,7 @@ REGISTER_OP("FractionalMaxPool")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: {float, double, int32, int64}")
+ .SetShapeFn(FractionalPoolShapeFn)
.Doc(R"doc(
Performs fractional max pooling on the input.
@@ -1646,6 +1680,9 @@ REGISTER_OP("FractionalMaxPoolGrad")
.Output("output: T")
.Attr("overlapping: bool = false")
.Attr("T: {float, double, int32, int64}")
+ .SetShapeFn([](InferenceContext* c) {
+ return shape_inference::UnchangedShapeWithRank(c, 4);
+ })
.Doc(R"doc(
Computes gradient of the FractionalMaxPool function.
@@ -1683,6 +1720,7 @@ REGISTER_OP("FractionalAvgPool")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: {float, double, int32, int64}")
+ .SetShapeFn(FractionalPoolShapeFn)
.Doc(R"doc(
Performs fractional average pooling on the input.
@@ -1731,6 +1769,16 @@ REGISTER_OP("FractionalAvgPoolGrad")
.Output("output: T")
.Attr("overlapping: bool = false")
.Attr("T: {float, double, int32, int64}")
+ .SetShapeFn([](InferenceContext* c) {
+ if (c->input_tensor(0) != nullptr) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
+ } else {
+ c->set_output(0, c->UnknownShapeOfRank(4));
+ }
+ return Status::OK();
+ })
.Doc(R"doc(
Computes gradient of the FractionalAvgPool function.
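
In other words, FractionalAvgPoolGrad only gets a fully known output shape when its first input (orig_input_tensor_shape, a 1-D int tensor) is a graph-time constant; otherwise the result is unknown with rank 4. A rough Python equivalent of that branch (a sketch with a hypothetical helper name, not part of the patch):

def fractional_avg_pool_grad_shape(orig_input_shape_value):
    # orig_input_shape_value is the constant value of input 0 if it is
    # known at graph-construction time, else None.
    if orig_input_shape_value is not None:
        return list(orig_input_shape_value)
    return [None, None, None, None]

print(fractional_avg_pool_grad_shape([1, 2, 3, 4]))  # [1, 2, 3, 4]
print(fractional_avg_pool_grad_shape(None))          # [None, None, None, None]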
diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc
index 0251b15058..d804cf8281 100644
--- a/tensorflow/core/ops/nn_ops_test.cc
+++ b/tensorflow/core/ops/nn_ops_test.cc
@@ -338,4 +338,64 @@ TEST(NNOpsTest, Dilation2DShapeTest) {
INFER_OK(op, "[1,7,7,2];[2,2,2]", "[d0_0,5,5,d1_2]");
}
+TEST(NNOpsTest, FractionalPool_ShapeFn) {
+ for (const char* op_name : {"FractionalAvgPool", "FractionalMaxPool"}) {
+ ShapeInferenceTestOp op(op_name);
+ auto set_op = [&op, op_name](const std::vector<float>& pooling_ratio) {
+ TF_ASSERT_OK(NodeDefBuilder("test", op_name)
+ .Input("input", 0, DT_FLOAT)
+ .Attr("pooling_ratio", pooling_ratio)
+ .Finalize(&op.node_def));
+ };
+
+ set_op(std::vector<float>{2.0, 1, 1 / 1.5, 1 / 2.0});
+
+ // Rank check.
+ INFER_ERROR("must be rank 4", op, "[?,?,?]");
+
+ // Unknown inputs.
+ INFER_OK(op, "?", "[?,?,?,?];[?];[?]");
+ INFER_OK(op, "[?,?,?,?]", "[?,?,?,?];[?];[?]");
+
+ INFER_OK(op, "[10,20,30,40]", "[5,20,45,80];[20];[45]");
+ INFER_OK(op, "[?,20,30,40]", "[?,20,45,80];[20];[45]");
+ INFER_OK(op, "[10,?,30,40]", "[5,?,45,80];[?];[45]");
+ INFER_OK(op, "[10,20,?,40]", "[5,20,?,80];[20];[?]");
+ INFER_OK(op, "[10,20,30,?]", "[5,20,45,?];[20];[45]");
+
+ // Wrong number of values for pooling_ratio.
+ set_op(std::vector<float>{.5, 1.0, 1.5});
+ INFER_ERROR("pooling_ratio field", op, "?");
+ set_op(std::vector<float>{1, 2, 3, 4, 5});
+ INFER_ERROR("pooling_ratio field", op, "?");
+
+ // Check dim size >= 0.
+ set_op(std::vector<float>{-1, 2, 3, 4});
+ INFER_ERROR("is negative", op, "[1,2,3,4]");
+ }
+}
+
+TEST(NNOpsTest, FractionalMaxPoolGrad) {
+ ShapeInferenceTestOp op("FractionalMaxPoolGrad");
+
+ // Note that the shape fn only uses input[0] for computation.
+ INFER_ERROR("must be rank 4", op, "[?,?,?];?;?;?;?");
+ INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]");
+ INFER_OK(op, "[?,?,3,4];?;?;?;?", "in0");
+}
+
+TEST(NNOpsTest, FractionalAvgPoolGrad) {
+ ShapeInferenceTestOp op("FractionalAvgPoolGrad");
+ op.input_tensors.resize(1);
+
+ // With no input shape tensor, returns unknown of rank 4.
+ INFER_OK(op, "?;?;?;?", "[?,?,?,?]");
+
+ // When input tensor is known, its values determine output shape.
+ std::vector<int32> shape{1, 2, 3, 4};
+ Tensor shape_t = test::AsTensor<int32>(shape);
+ op.input_tensors[0] = &shape_t;
+ INFER_OK(op, "[5];?;?;?", "[1,2,3,4]");
+}
+
} // end namespace tensorflow
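
The expectations in FractionalPool_ShapeFn follow directly from the floor(input / ratio) rule; ratios below 1 (here 1 / 1.5 and 1 / 2.0) enlarge the corresponding dimension. A quick check of the [10,20,30,40] case, again as a sketch outside the patch (values near a boundary depend on floating-point rounding):

import math

pooling_ratio = [2.0, 1.0, 1 / 1.5, 1 / 2.0]
input_shape = [10, 20, 30, 40]
print([int(math.floor(d / r)) for d, r in zip(input_shape, pooling_ratio)])
# [5, 20, 45, 80], matching INFER_OK(op, "[10,20,30,40]", "[5,20,45,80];[20];[45]")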
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 9560ffe66d..20ef9f4c0a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -953,37 +953,14 @@ def _AvgPoolGradShape(op):
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
-@ops.RegisterShape("FractionalMaxPool")
-@ops.RegisterShape("FractionalAvgPool")
-def _fractional_pool_shape(op):
- input_dims = op.inputs[0].get_shape().with_rank(4).as_list()
- pooling_ratio = op.get_attr("pooling_ratio")
- output_dims = np.divide(input_dims, pooling_ratio).astype(int)
- return [
- # output.
- tensor_shape.TensorShape(output_dims),
- # row_pooling_sequence.
- tensor_shape.TensorShape([output_dims[1]]),
- # col_pooling_sequence.
- tensor_shape.TensorShape([output_dims[2]])
- ]
-
-
-@ops.RegisterShape("FractionalMaxPoolGrad")
-def _fractional_max_pool_grad_shape(op):
- """Shape function for the FractionalMaxPoolGrad op."""
- orig_input_shape = op.inputs[0].get_shape().with_rank(4)
- return [orig_input_shape]
+ops.RegisterShape("FractionalMaxPool")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("FractionalAvgPool")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("FractionalMaxPoolGrad")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("FractionalAvgPoolGrad")
def _fractional_avg_pool_grad_shape(op):
- """Shape function for the FractionalAvgPoolGrad op."""
- orig_input_shape = tensor_util.constant_value(op.inputs[0])
- if orig_input_shape is not None:
- return [tensor_shape.TensorShape(orig_input_shape.tolist())]
- else:
- return [tensor_shape.unknown_shape(ndims=4)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("Conv2DBackpropFilter")