diff options
author | 2018-09-11 19:59:11 +0800 | |
---|---|---|
committer | 2018-09-11 19:59:11 +0800 | |
commit | 9b3a93edf5a1f259bfe5230cc3b6c076573d4ec9 (patch) | |
tree | cbb0548282ba1584ed91a1be8f89b03ec882f287 /tensorflow/core/ops/nn_ops.cc | |
parent | 90cf7fb7786c8a9c135ef73482856b082e80f61a (diff) | |
parent | e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff) |
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 187 |
1 file changed, 184 insertions, 3 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 023f988f80..6c318e358a 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -960,7 +960,7 @@ REGISTER_OP("Dilation2DBackpropFilter") REGISTER_OP("Relu") .Input("features: T") .Output("activations: T") - .Attr("T: realnumbertype") + .Attr("T: {realnumbertype, qint8}") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("ReluGrad") @@ -1024,6 +1024,7 @@ REGISTER_OP("SeluGrad") .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +// TODO(b/111515541): change T to {half, bfloat16, float, double} REGISTER_OP("Softplus") .Input("features: T") .Output("activations: T") @@ -1037,6 +1038,7 @@ REGISTER_OP("SoftplusGrad") .Attr("T: realnumbertype") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); +// TODO(b/111515541): change T to {half, bfloat16, float, double} REGISTER_OP("Softsign") .Input("features: T") .Output("activations: T") @@ -1751,6 +1753,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +REGISTER_OP("_MklConv3D") + .Input("input: T") + .Input("filter: T") + .Input("mkl_input: uint8") + .Input("mkl_filter: uint8") + .Output("output: T") + .Output("filter_output: T") + .Output("mkl_output: uint8") + .Output("mkl_filter_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .SetShapeFn(shape_inference::Conv3DShape) + .Doc(R"doc( +MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklConv3DBackpropInputV2") + .Input("input_sizes: Tshape") + .Input("filter: T") + .Input("out_backprop: T") + .Input("mkl_input_sizes: uint8") + .Input("mkl_filter: uint8") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int) >= 5") + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .Attr("Tshape: {int32, int64} = DT_INT32") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklConv3DBackpropFilterV2") + .Input("input: T") + .Input("filter_sizes: int32") + .Input("out_backprop: T") + .Input("mkl_input: uint8") + .Input("mkl_filter_size: uint8") + .Input("mkl_out_backprop: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: {half, float, double}") + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("dilations: list(int) = [1, 1, 1, 1, 1]") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the +gradients of convolution with respect to the filter. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + REGISTER_OP("_MklRelu") .Input("features: T") .Input("mkl_features: uint8") @@ -1958,6 +2041,104 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is expected to invoke these operators. )doc"); +REGISTER_OP("_MklAvgPool3D") + .Input("value: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {float, half, double}") + .SetShapeFn(shape_inference::Pool3DShape) + .Doc(R"doc( +MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling +on the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + + +REGISTER_OP("_MklAvgPool3DGrad") + .Input("orig_input_shape: int32") + .Input("grad: T") + .Input("mkl_orig_input: uint8") + .Input("mkl_grad: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {float, half, double}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); + TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); + c->set_output(0, s); + return Status::OK(); + }) + .Doc(R"doc( +MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients +of AvgPool function. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. 
+)doc"); + +REGISTER_OP("_MklMaxPool3D") + .Input("input: T") + .Input("mkl_input: uint8") + .Output("output: T") + .Output("workspace: uint8") + .Output("mkl_output: uint8") + .Output("mkl_workspace: uint8") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {half, bfloat16, float}") + .Attr("workspace_enabled: bool = false") + .SetShapeFn(shape_inference::Pool3DShape) + .Doc(R"doc( +MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling +on the input. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + +REGISTER_OP("_MklMaxPool3DGrad") + .Input("orig_input: TInput") + .Input("orig_output: TInput") + .Input("grad: T") + .Input("workspace: uint8") + .Input("mkl_orig_input: uint8") + .Input("mkl_orig_output: uint8") + .Input("mkl_grad: uint8") + .Input("mkl_workspace: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("ksize: list(int) >= 5") + .Attr("strides: list(int) >= 5") + .Attr(GetPaddingAttrString()) + .Attr(GetConvnet3dDataFormatAttrString()) + .Attr("T: {half, bfloat16, float} = DT_FLOAT") + .Attr("TInput: {half, bfloat16, float} = DT_FLOAT") + .Attr("workspace_enabled: bool = false") + .SetShapeFn([](InferenceContext* c) { + return UnchangedShapeWithRank(c, 5); + }) + .Doc(R"doc( +MKL version of MaxPool3DGrad operator. Uses MKL DNN APIs to compute gradients +of MaxPool3D function. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators.
+)doc"); + REGISTER_OP("_MklLRN") .Input("input: T") .Input("mkl_input: uint8") @@ -2176,7 +2357,7 @@ REGISTER_OP("_MklToTf") .Input("mkl_input: uint8") .Output("output: T") .Attr("T: {half, float, double}") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to convert a tensor from MKL layout to TensorFlow layout. @@ -2198,7 +2379,7 @@ REGISTER_OP("_MklInputConversion") .Attr( "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " "complex64, complex128}") - .Attr(GetConvnetDataFormatAttrString()) + .Attr(GetConvnetDataFormat2D3DAttrString()) .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( MKL operator to process the inputs to an elementwise MKL op. Both inputs |