Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r--  tensorflow/core/ops/nn_ops.cc  338
1 file changed, 326 insertions(+), 12 deletions(-)
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e56b27b0c0..e9d5897af0 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -89,7 +89,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {float, half, double}")
+ .Attr("T: realnumbertype")
.SetShapeFn(shape_inference::AvgPoolShape)
.Doc(R"doc(
Performs average pooling on the input.
@@ -117,7 +117,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {float, half, double}")
+ .Attr("T: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
// NOTE(mrry): We could in principle work out the shape from the
// gradients and the attrs, but if we do not know orig_input_shape
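The two hunks above replace an explicit dtype list with the named
"realnumbertype" constraint. A minimal, hypothetical sketch of the attr
syntax (the op name is illustrative and not part of this diff): the string
after "T:" names the dtypes the attr may take, and an optional trailing
"= DT_FLOAT" would supply a default value.

REGISTER_OP("ExamplePoolOp")
    .Input("input: T")
    .Output("output: T")
    // "realnumbertype" is the framework's named list of real numeric dtypes;
    // an explicit list such as {float, half, double} restricts T further.
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnchangedShape);
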
@@ -1186,15 +1186,16 @@ data_format: The data format of the input and output data. With the
)doc");
REGISTER_OP("MaxPool3DGrad")
- .Input("orig_input: float")
- .Input("orig_output: float")
+ .Input("orig_input: TInput")
+ .Input("orig_output: TInput")
.Input("grad: T")
.Output("output: T")
.Attr("ksize: list(int) >= 5 ")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: numbertype")
+ .Attr("T: numbertype = DT_FLOAT")
+ .Attr("TInput: numbertype = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
})
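The hunk above generalizes MaxPool3DGrad from hard-coded float inputs to two
independently constrained type attrs. A rough, hypothetical sketch of the
pattern (illustrative op name, not part of this diff); both attrs default to
DT_FLOAT, which keeps previously serialized graphs valid:

REGISTER_OP("ExampleGradOp")
    .Input("orig_input: TInput")  // forward-pass tensors typed by TInput
    .Input("grad: T")             // incoming gradient typed by T
    .Output("output: T")
    .Attr("T: numbertype = DT_FLOAT")
    .Attr("TInput: numbertype = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
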
@@ -1216,6 +1217,44 @@ data_format: The data format of the input and output data. With the
[batch, in_channels, in_depth, in_height, in_width].
)doc");
+REGISTER_OP("MaxPool3DGradGrad")
+ .Input("orig_input: T")
+ .Input("orig_output: T")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("ksize: list(int) >= 5 ")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: realnumbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
+ ShapeHandle unused;
+ // Validate 'orig_input' is the same shape as 'grad'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
+ // Validate 'orig_output' is the same shape as 'output'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Computes second-order gradients of the maxpooling function.
+
+ksize: 1-D tensor of length 5. The size of the window for each dimension of
+ the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+strides: 1-D tensor of length 5. The stride of the sliding window for each
+ dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+padding: The type of padding algorithm to use.
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: Output backprop of shape `[batch, depth, rows, cols, channels]`.
+output: Gradients of gradients w.r.t. the input to `max_pool`.
+data_format: The data format of the input and output data. With the
+ default format "NDHWC", the data is stored in the order of:
+ [batch, in_depth, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCDHW", the data storage order is:
+ [batch, in_channels, in_depth, in_height, in_width].
+)doc");
+
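The shape function above chains Status-returning checks: Pool3DShape computes
the pooled output shape, and the two Merge calls only validate that
'orig_input' agrees with 'grad' and 'orig_output' agrees with the output (the
merged handle is discarded). As a hedged aside, a simplified hand-rolled
equivalent of the propagation macro (not the actual TensorFlow definition):

#define RETURN_IF_ERROR_SKETCH(expr)           \
  do {                                         \
    ::tensorflow::Status _status = (expr);     \
    if (!_status.ok()) return _status;         \
  } while (0)
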
// --------------------------------------------------------------------------
REGISTER_OP("L2Loss")
@@ -1303,7 +1342,7 @@ output: The gradients for LRN.
// --------------------------------------------------------------------------
REGISTER_OP("MaxPool")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: realnumbertype = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
@@ -1336,7 +1375,7 @@ REGISTER_OP("MaxPoolGrad")
.Input("orig_output: T")
.Input("grad: T")
.Output("output: T")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: realnumbertype = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
})
@@ -1358,6 +1397,43 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`.
output: Gradients w.r.t. the input to `max_pool`.
)doc");
+REGISTER_OP("MaxPoolGradGrad")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("orig_input: T")
+ .Input("orig_output: T")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
+ ShapeHandle unused;
+ // Validate 'orig_input' is the same shape as 'grad'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
+ // Validate 'orig_output' is the same shape as 'output'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Computes second-order gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+data_format: Specify the data format of the input and output data. With the
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+orig_input: The original input tensor.
+orig_output: The original output tensor.
+grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`.
+output: Gradients of gradients w.r.t. the input to `max_pool`.
+)doc");
+
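Conceptually, MaxPoolGrad scatters each incoming gradient value to the argmax
position of its pooling window, so differentiating it with respect to that
incoming gradient is a gather at the same positions: one value per pooled
output element, which is why the output here has the shape of 'orig_output'.
A hedged, illustrative sketch (not the actual kernel; assumes flattened
buffers and precomputed argmax indices):

#include <cstdint>
#include <vector>

std::vector<float> GatherAtArgmax(const std::vector<float>& grad,
                                  const std::vector<int64_t>& argmax) {
  // grad is laid out like the original input; argmax holds one flattened
  // input position per pooled output element.
  std::vector<float> out(argmax.size());
  for (size_t i = 0; i < argmax.size(); ++i) out[i] = grad[argmax[i]];
  return out;
}
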
REGISTER_OP("MaxPoolWithArgmax")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
@@ -1366,7 +1442,7 @@ REGISTER_OP("MaxPoolWithArgmax")
.Input("input: T")
.Output("output: T")
.Output("argmax: Targmax")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
c->set_output(1, c->output(0));
@@ -1397,7 +1473,7 @@ REGISTER_OP("MaxPoolGradWithArgmax")
.Input("grad: T")
.Input("argmax: Targmax")
.Output("output: T")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
})
@@ -1415,6 +1491,39 @@ argmax: The indices of the maximum values chosen for each output of `max_pool`.
output: Gradients w.r.t. the input of `max_pool`.
)doc");
+REGISTER_OP("MaxPoolGradGradWithArgmax")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr("Targmax: {int32, int64}")
+ .Input("input: T")
+ .Input("grad: T")
+ .Input("argmax: Targmax")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
+ ShapeHandle unused;
+ // Validate 'input' is the same shape as 'grad'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
+ // Validate 'argmax' is the same shape as 'output'
+ TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Computes second-order gradients of the maxpooling function.
+
+ksize: The size of the window for each dimension of the input tensor.
+strides: The stride of the sliding window for each dimension of the
+ input tensor.
+padding: The type of padding algorithm to use.
+input: The original input.
+grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the
+ input of `max_pool`.
+argmax: The indices of the maximum values chosen for each output of `max_pool`.
+output: Gradients of gradients w.r.t. the input of `max_pool`.
+)doc");
+
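For the argmax-based variants above, 'argmax' holds flattened positions as
documented for MaxPoolWithArgmax: with an NHWC input of shape
[batch, height, width, channels], a maximum at [b, y, x, c] flattens to
((b * height + y) * width + x) * channels + c. A hedged illustration (the
helper name is hypothetical, not part of this diff):

#include <cstdint>

int64_t FlattenedArgmaxIndex(int64_t b, int64_t y, int64_t x, int64_t c,
                             int64_t height, int64_t width, int64_t channels) {
  return ((b * height + y) * width + x) * channels + c;
}
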
// --------------------------------------------------------------------------
REGISTER_OP("Dilation2D")
@@ -2517,7 +2626,10 @@ REGISTER_OP("MklConv2D")
.Attr(GetConvnetDataFormatAttrString())
.SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
-MKL version of Conv2D
+MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
)doc");
REGISTER_OP("MklConv2DWithBias")
@@ -2533,14 +2645,216 @@ REGISTER_OP("MklConv2DWithBias")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
- .Attr(GetConvnetDataFormatAttrString());
+ .Attr(GetConvnetDataFormatAttrString())
+ .Doc(R"doc(
+MKL version of the Conv2D and BiasAdd operators. Uses MKL DNN APIs to perform
+2D convolution and add bias to the output of the convolution.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklConv2DBackpropFilter")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Input("filter_sizes: int32")
+ .Input("mkl_filter_size: uint8")
+ .Input("out_backprop: T")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ return InputTensorShapeOrUnknown(c, 2 /* input_idx */, 4 /* ndims */);
+ })
+ .Doc(R"doc(
+MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the filter.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklConv2DWithBiasBackpropBias")
+ .Input("out_backprop: T")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr(GetConvnetDataFormatAttrString())
+ .Doc(R"doc(
+MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the bias.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklConv2DBackpropInput")
+ .Input("input_sizes: int32")
+ .Input("mkl_input_sizes: uint8")
+ .Input("filter: T")
+ .Input("mkl_filter: uint8")
+ .Input("out_backprop: T")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr("use_cudnn_on_gpu: bool = true")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
+ .Doc(R"doc(
+MKL version of Conv2DBackpropInput. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the input.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklRelu")
+ .Input("features: T")
+ .Input("mkl_features: uint8")
+ .Output("activations: T")
+ .Output("mkl_activations: uint8")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+MKL version of Relu operator. Uses MKL DNN APIs to implement the Relu operation.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklReluGrad")
+ .Input("gradients: T")
+ .Input("mkl_gradients: uint8")
+ .Input("features: T")
+ .Input("mkl_features: uint8")
+ .Output("backprops: T")
+ .Output("mkl_backprops: uint8")
+ .Attr("T: realnumbertype")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
+ .Doc(R"doc(
+MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
+linear gradients for the Relu operation.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklMaxPool")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("workspace_enabled: bool = false")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Output("workspace: T")
+ .Output("mkl_workspace: uint8")
+ .SetShapeFn(shape_inference::MaxPoolShape)
+ .Doc(R"doc(
+MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
+on the input.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklMaxPoolGrad")
+ .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr("workspace_enabled: bool = false")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Input("orig_input: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("orig_output: T")
+ .Input("mkl_orig_output: uint8")
+ .Input("grad: T")
+ .Input("mkl_grad: uint8")
+ .Input("workspace: T")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 4);
+ })
+ .Doc(R"doc(
+MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of the
+MaxPool operator.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklAvgPool")
+ .Input("value: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn(shape_inference::AvgPoolShape)
+ .Doc(R"doc(
+MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
+on the input.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("MklAvgPoolGrad")
+ .Input("orig_input_shape: int32")
+ .Input("mkl_orig_input: uint8")
+ .Input("grad: T")
+ .Input("mkl_grad: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 4")
+ .Attr("strides: list(int) >= 4")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnetDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */);
+ })
+ .Doc(R"doc(
+MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
+of the AvgPool function.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
REGISTER_OP("MklToTf")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
.Attr("T: {half, float, double}")
- .Attr(GetConvnetDataFormatAttrString());
+ .Attr(GetConvnetDataFormatAttrString())
+ .Doc(R"doc(
+MKL operator to convert a tensor from MKL layout to TensorFlow layout.
+
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass is
+expected to invoke these operators.
+)doc");
#endif // INTEL_MKL
} // namespace tensorflow