diff options
-rw-r--r-- | tensorflow/core/kernels/fractional_max_pool_op.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 48 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops_test.cc | 60 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 31 |
4 files changed, 114 insertions, 27 deletions
diff --git a/tensorflow/core/kernels/fractional_max_pool_op.cc b/tensorflow/core/kernels/fractional_max_pool_op.cc index 482491b504..a422433ecf 100644 --- a/tensorflow/core/kernels/fractional_max_pool_op.cc +++ b/tensorflow/core/kernels/fractional_max_pool_op.cc @@ -72,6 +72,8 @@ class FractionalMaxPoolOp : public OpKernel { } // Output size. for (int i = 0; i < tensor_in_and_out_dims; ++i) { + // This must match the same logic in the shape function in + // core/ops/nn_ops.cc. output_size_.push_back( static_cast<int>(floor(input_size_[i] / pooling_ratio_[i]))); DCHECK_GT(output_size_[i], 0); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 9cf1ff766a..9fea2213a5 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -45,6 +45,39 @@ Status InputTensorShapeOrUnknown(InferenceContext* c, int input_idx, return Status::OK(); } +Status FractionalPoolShapeFn(InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + + std::vector<float> pooling_ratio; + TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio)); + if (pooling_ratio.size() != 4) { + return errors::InvalidArgument( + "pooling_ratio field must specify 4 dimensions"); + } + std::vector<DimensionHandle> output_dims; + for (int i = 0; i < 4; ++i) { + DimensionHandle d = c->Dim(input, i); + if (c->ValueKnown(d)) { + // This must match the same logic in the kernel function in + // core/kernels/fractional_max_pool_op.cc. + auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i])); + if (val < 0) { + return errors::InvalidArgument("Size computed for dim ", i, + " is negative: ", val); + } + output_dims.push_back(c->MakeDim(val)); + } else { + output_dims.push_back(c->UnknownDim()); + } + } + + c->set_output(0, c->MakeShape(output_dims)); + c->set_output(1, c->Vector(output_dims[1])); + c->set_output(2, c->Vector(output_dims[2])); + return Status::OK(); +} + } // namespace // -------------------------------------------------------------------------- @@ -1572,6 +1605,7 @@ REGISTER_OP("FractionalMaxPool") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("T: {float, double, int32, int64}") + .SetShapeFn(FractionalPoolShapeFn) .Doc(R"doc( Performs fractional max pooling on the input. @@ -1646,6 +1680,9 @@ REGISTER_OP("FractionalMaxPoolGrad") .Output("output: T") .Attr("overlapping: bool = false") .Attr("T: {float, double, int32, int64}") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShapeWithRank(c, 4); + }) .Doc(R"doc( Computes gradient of the FractionalMaxPool function. @@ -1683,6 +1720,7 @@ REGISTER_OP("FractionalAvgPool") .Attr("seed: int = 0") .Attr("seed2: int = 0") .Attr("T: {float, double, int32, int64}") + .SetShapeFn(FractionalPoolShapeFn) .Doc(R"doc( Performs fractional average pooling on the input. @@ -1731,6 +1769,16 @@ REGISTER_OP("FractionalAvgPoolGrad") .Output("output: T") .Attr("overlapping: bool = false") .Attr("T: {float, double, int32, int64}") + .SetShapeFn([](InferenceContext* c) { + if (c->input_tensor(0) != nullptr) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); + } else { + c->set_output(0, c->UnknownShapeOfRank(4)); + } + return Status::OK(); + }) .Doc(R"doc( Computes gradient of the FractionalAvgPool function. diff --git a/tensorflow/core/ops/nn_ops_test.cc b/tensorflow/core/ops/nn_ops_test.cc index 0251b15058..d804cf8281 100644 --- a/tensorflow/core/ops/nn_ops_test.cc +++ b/tensorflow/core/ops/nn_ops_test.cc @@ -338,4 +338,64 @@ TEST(NNOpsTest, Dilation2DShapeTest) { INFER_OK(op, "[1,7,7,2];[2,2,2]", "[d0_0,5,5,d1_2]"); } +TEST(NNOpsTest, FractionalPool_ShapeFn) { + for (const char* op_name : {"FractionalAvgPool", "FractionalMaxPool"}) { + ShapeInferenceTestOp op(op_name); + auto set_op = [&op, op_name](const std::vector<float>& pooling_ratio) { + TF_ASSERT_OK(NodeDefBuilder("test", op_name) + .Input("input", 0, DT_FLOAT) + .Attr("pooling_ratio", pooling_ratio) + .Finalize(&op.node_def)); + }; + + set_op(std::vector<float>{2.0, 1, 1 / 1.5, 1 / 2.0}); + + // Rank check. + INFER_ERROR("must be rank 4", op, "[?,?,?]"); + + // Unknown inputs. + INFER_OK(op, "?", "[?,?,?,?];[?];[?]"); + INFER_OK(op, "[?,?,?,?]", "[?,?,?,?];[?];[?]"); + + INFER_OK(op, "[10,20,30,40]", "[5,20,45,80];[20];[45]"); + INFER_OK(op, "[?,20,30,40]", "[?,20,45,80];[20];[45]"); + INFER_OK(op, "[10,?,30,40]", "[5,?,45,80];[?];[45]"); + INFER_OK(op, "[10,20,?,40]", "[5,20,?,80];[20];[?]"); + INFER_OK(op, "[10,20,30,?]", "[5,20,45,?];[20];[45]"); + + // Wrong number of values for pooling_ratio. + set_op(std::vector<float>{.5, 1.0, 1.5}); + INFER_ERROR("pooling_ratio field", op, "?"); + set_op(std::vector<float>{1, 2, 3, 4, 5}); + INFER_ERROR("pooling_ratio field", op, "?"); + + // Check dim size >= 0. + set_op(std::vector<float>{-1, 2, 3, 4}); + INFER_ERROR("is negative", op, "[1,2,3,4]"); + } +} + +TEST(NNOpsTest, FractionalMaxPoolGrad) { + ShapeInferenceTestOp op("FractionalMaxPoolGrad"); + + // Note that the shape fn only uses input[0] for computation. + INFER_ERROR("must be rank 4", op, "[?,?,?];?;?;?;?"); + INFER_OK(op, "?;?;?;?;?", "[?,?,?,?]"); + INFER_OK(op, "[?,?,3,4];?;?;?;?", "in0"); +} + +TEST(NNOpsTest, FractionalAvgPoolGrad) { + ShapeInferenceTestOp op("FractionalAvgPoolGrad"); + op.input_tensors.resize(1); + + // With no input shape tensor, returns unknown of rank 4. + INFER_OK(op, "?;?;?;?", "[?,?,?,?]"); + + // When input tensor is known, its values determine output shape. + std::vector<int32> shape{1, 2, 3, 4}; + Tensor shape_t = test::AsTensor<int32>(shape); + op.input_tensors[0] = &shape_t; + INFER_OK(op, "[5];?;?;?", "[1,2,3,4]"); +} + } // end namespace tensorflow diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 9560ffe66d..20ef9f4c0a 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -953,37 +953,14 @@ def _AvgPoolGradShape(op): return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0]) -@ops.RegisterShape("FractionalMaxPool") -@ops.RegisterShape("FractionalAvgPool") -def _fractional_pool_shape(op): - input_dims = op.inputs[0].get_shape().with_rank(4).as_list() - pooling_ratio = op.get_attr("pooling_ratio") - output_dims = np.divide(input_dims, pooling_ratio).astype(int) - return [ - # output. - tensor_shape.TensorShape(output_dims), - # row_pooling_sequence. - tensor_shape.TensorShape([output_dims[1]]), - # col_pooling_sequence. - tensor_shape.TensorShape([output_dims[2]]) - ] - - -@ops.RegisterShape("FractionalMaxPoolGrad") -def _fractional_max_pool_grad_shape(op): - """Shape function for the FractionalMaxPoolGrad op.""" - orig_input_shape = op.inputs[0].get_shape().with_rank(4) - return [orig_input_shape] +ops.RegisterShape("FractionalMaxPool")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("FractionalAvgPool")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("FractionalMaxPoolGrad")(common_shapes.call_cpp_shape_fn) @ops.RegisterShape("FractionalAvgPoolGrad") def _fractional_avg_pool_grad_shape(op): - """Shape function for the FractionalAvgPoolGrad op.""" - orig_input_shape = tensor_util.constant_value(op.inputs[0]) - if orig_input_shape is not None: - return [tensor_shape.TensorShape(orig_input_shape.tolist())] - else: - return [tensor_shape.unknown_shape(ndims=4)] + return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0]) @ops.RegisterShape("Conv2DBackpropFilter") |