diff options
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 1018742521..0a96258dd1 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1368,6 +1368,34 @@ input: 4-D input to pool over. output: The max pooled output tensor. )doc"); +REGISTER_OP("MaxPoolV2") + .Attr("T: realnumbertype = DT_FLOAT") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("input: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3)); + return Status::OK(); + }) + .Doc(R"doc( +Performs max pooling on the input. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +input: 4-D input to pool over. +output: The max pooled output tensor. +)doc"); + REGISTER_OP("MaxPoolGrad") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -1399,6 +1427,37 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`. output: Gradients w.r.t. the input to `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradV2") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .Attr("T: realnumbertype = DT_FLOAT") + .SetShapeFn([](InferenceContext* c) { + return UnchangedShapeWithRank(c, 4); + }) + .Doc(R"doc( +Computes gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients w.r.t. the output of `max_pool`. +output: Gradients w.r.t. the input to `max_pool`. +)doc"); + REGISTER_OP("MaxPoolGradGrad") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -1436,6 +1495,43 @@ grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. output: Gradients of gradients w.r.t. the input to `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradGradV2") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Input("ksize: int32") + .Input("strides: int32") + .Output("output: T") + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5)); + ShapeHandle unused; + // Validate 'orig_input' is the same shape as 'grad' + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); + // Validate 'orig_output' is same shape as 'output' + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Computes second-order gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +output: Gradients of gradients w.r.t. the input to `max_pool`. +)doc"); + REGISTER_OP("MaxPoolWithArgmax") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") |