diff options
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 338 |
1 files changed, 326 insertions, 12 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index e56b27b0c0..e9d5897af0 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -89,7 +89,7 @@ REGISTER_OP("AvgPool") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: realnumbertype") .SetShapeFn(shape_inference::AvgPoolShape) .Doc(R"doc( Performs average pooling on the input. @@ -117,7 +117,7 @@ REGISTER_OP("AvgPoolGrad") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, half, double}") + .Attr("T: realnumbertype") .SetShapeFn([](InferenceContext* c) { // NOTE(mrry): We could in principle work out the shape from the // gradients and the attrs, but if we do not know orig_input_shape @@ -1186,15 +1186,16 @@ data_format: The data format of the input and output data. With the )doc"); REGISTER_OP("MaxPool3DGrad") - .Input("orig_input: float") - .Input("orig_output: float") + .Input("orig_input: TInput") + .Input("orig_output: TInput") .Input("grad: T") .Output("output: T") .Attr("ksize: list(int) >= 5 ") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr(GetConvnet3dDataFormatAttrString()) - .Attr("T: numbertype") + .Attr("T: numbertype = DT_FLOAT") + .Attr("TInput: numbertype = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { return UnchangedShapeWithRank(c, 5); }) @@ -1216,6 +1217,44 @@ data_format: The data format of the input and output data. With the [batch, in_channels, in_depth, in_height, in_width]. )doc"); +REGISTER_OP("MaxPool3DGradGrad") + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Output("output: T") + .Attr("ksize: list(int) >= 5 ") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c)); + ShapeHandle unused; + // Validate 'orig_input' is the same shape as 'grad' + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); + // Validate 'orig_output' is same shape as 'output' + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Computes second-order gradients of the maxpooling function. + +ksize: 1-D tensor of length 5. The size of the window for each dimension of + the input tensor. Must have `ksize[0] = ksize[4] = 1`. +strides: 1-D tensor of length 5. The stride of the sliding window for each + dimension of `input`. Must have `strides[0] = strides[4] = 1`. +padding: The type of padding algorithm to use. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: Output backprop of shape `[batch, depth, rows, cols, channels]`. +output: Gradients of gradients w.r.t. the input to `max_pool`. +data_format: The data format of the input and output data. With the + default format "NDHWC", the data is stored in the order of: + [batch, in_depth, in_height, in_width, in_channels]. + Alternatively, the format could be "NCDHW", the data storage order is: + [batch, in_channels, in_depth, in_height, in_width]. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("L2Loss") @@ -1303,7 +1342,7 @@ output: The gradients for LRN. // -------------------------------------------------------------------------- REGISTER_OP("MaxPool") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: realnumbertype = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) @@ -1336,7 +1375,7 @@ REGISTER_OP("MaxPoolGrad") .Input("orig_output: T") .Input("grad: T") .Output("output: T") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: realnumbertype = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { return UnchangedShapeWithRank(c, 4); }) @@ -1358,6 +1397,43 @@ grad: 4-D. Gradients w.r.t. the output of `max_pool`. output: Gradients w.r.t. the input to `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradGrad") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Output("output: T") + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); + ShapeHandle unused; + // Validate 'orig_input' is the same shape as 'grad' + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); + // Validate 'orig_output' is same shape as 'output' + TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Computes second-order gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +data_format: Specify the data format of the input and output data. With the + default format "NHWC", the data is stored in the order of: + [batch, in_height, in_width, in_channels]. + Alternatively, the format could be "NCHW", the data storage order of: + [batch, in_channels, in_height, in_width]. +orig_input: The original input tensor. +orig_output: The original output tensor. +grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +output: Gradients of gradients w.r.t. the input to `max_pool`. +)doc"); + REGISTER_OP("MaxPoolWithArgmax") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") @@ -1366,7 +1442,7 @@ REGISTER_OP("MaxPoolWithArgmax") .Input("input: T") .Output("output: T") .Output("argmax: Targmax") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: realnumbertype") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); c->set_output(1, c->output(0)); @@ -1397,7 +1473,7 @@ REGISTER_OP("MaxPoolGradWithArgmax") .Input("grad: T") .Input("argmax: Targmax") .Output("output: T") - .Attr("T: {float, half} = DT_FLOAT") + .Attr("T: realnumbertype") .SetShapeFn([](InferenceContext* c) { return UnchangedShapeWithRank(c, 4); }) @@ -1415,6 +1491,39 @@ argmax: The indices of the maximum values chosen for each output of `max_pool`. output: Gradients w.r.t. the input of `max_pool`. )doc"); +REGISTER_OP("MaxPoolGradGradWithArgmax") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr("Targmax: {int32, int64}") + .Input("input: T") + .Input("grad: T") + .Input("argmax: Targmax") + .Output("output: T") + .Attr("T: realnumbertype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); + ShapeHandle unused; + // Validate 'orig_input' is the same shape as 'grad' + TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused)); + // Validate 'argmax' is same shape as 'output' + TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused)); + return Status::OK(); + }) + .Doc(R"doc( +Computes second-order gradients of the maxpooling function. + +ksize: The size of the window for each dimension of the input tensor. +strides: The stride of the sliding window for each dimension of the + input tensor. +padding: The type of padding algorithm to use. +input: The original input. +grad: 4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the + input of `max_pool`. +argmax: The indices of the maximum values chosen for each output of `max_pool`. +output: Gradients of gradients w.r.t. the input of `max_pool`. +)doc"); + // -------------------------------------------------------------------------- REGISTER_OP("Dilation2D") @@ -2517,7 +2626,10 @@ REGISTER_OP("MklConv2D") .Attr(GetConvnetDataFormatAttrString()) .SetShapeFn(shape_inference::Conv2DShape) .Doc(R"doc( -MKL version of Conv2D +MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. )doc"); REGISTER_OP("MklConv2DWithBias") @@ -2533,14 +2645,216 @@ REGISTER_OP("MklConv2DWithBias") .Attr("strides: list(int)") .Attr("use_cudnn_on_gpu: bool = true") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()); + .Attr(GetConvnetDataFormatAttrString()) + .Doc(R"doc( +MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform +2D convolution and add Bias to the output of convolution. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklConv2DBackpropFilter") + .Input("input: T") + .Input("mkl_input: uint8") + .Input("filter_sizes: int32") + .Input("mkl_filter_size: uint8") + .Input("out_backprop: T") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn([](InferenceContext* c) { + return InputTensorShapeOrUnknown(c, 2 /* input_idx */, 4 /* ndims */); + }) + .Doc(R"doc( +MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the filter. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklConv2DWithBiasBackpropBias") + .Input("out_backprop: T") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr(GetConvnetDataFormatAttrString()) + .Doc(R"doc( +MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the bias. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklConv2DBackpropInput") + .Input("input_sizes: int32") + .Input("mkl_input_sizes: uint8") + .Input("filter: T") + .Input("mkl_filter: uint8") + .Input("out_backprop: T") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr("use_cudnn_on_gpu: bool = true") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .SetShapeFn([](InferenceContext* c) { + return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */); + }) + .Doc(R"doc( +MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklRelu") + .Input("features: T") + .Input("mkl_features: uint8") + .Output("activations: T") + .Output("mkl_activations: uint8") + .Attr("T: realnumbertype") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklReluGrad") + .Input("gradients: T") + .Input("mkl_gradients: uint8") + .Input("features: T") + .Input("mkl_features: uint8") + .Output("backprops: T") + .Output("mkl_backprops: uint8") + .Attr("T: realnumbertype") + .SetShapeFn(shape_inference::MergeBothInputsShapeFn) + .Doc(R"doc( +MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified +linear gradients for Relu operation. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklMaxPool") + .Attr("T: {float, half} = DT_FLOAT") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("workspace_enabled: bool = false") + .Input("input: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Output("workspace: T") + .Output("mkl_workspace: uint8") + .SetShapeFn(shape_inference::MaxPoolShape) + .Doc(R"doc( +MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling +on the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklMaxPoolGrad") + .Attr("T: {float, half} = DT_FLOAT") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr("workspace_enabled: bool = false") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Input("orig_input: T") + .Input("mkl_orig_input: uint8") + .Input("orig_output: T") + .Input("mkl_orig_output: uint8") + .Input("grad: T") + .Input("mkl_grad: uint8") + .Input("workspace: T") + .Input("mkl_workspace: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .SetShapeFn([](InferenceContext* c) { + return UnchangedShapeWithRank(c, 4); + }) + .Doc(R"doc( +MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of +MaxPool operator. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklAvgPool") + .Input("value: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("T: {float, half, double}") + .SetShapeFn(shape_inference::AvgPoolShape) + .Doc(R"doc( +MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling +on the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("MklAvgPoolGrad") + .Input("orig_input_shape: int32") + .Input("mkl_orig_input: uint8") + .Input("grad: T") + .Input("mkl_grad: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("ksize: list(int) >= 4") + .Attr("strides: list(int) >= 4") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnetDataFormatAttrString()) + .Attr("T: {float, half, double}") + .SetShapeFn([](InferenceContext* c) { + return InputTensorShapeOrUnknown(c, 0 /* input_idx */, 4 /* ndims */); + }) + .Doc(R"doc( +MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients +of AvgPool function. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); REGISTER_OP("MklToTf") .Input("input: T") .Input("mkl_input: uint8") .Output("output: T") .Attr("T: {half, float, double}") - .Attr(GetConvnetDataFormatAttrString()); + .Attr(GetConvnetDataFormatAttrString()) + .Doc(R"doc( +MKL operator to convert a tensor from MKL layout to TensorFlow layout. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); #endif // INTEL_MKL } // namespace tensorflow |