aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r--tensorflow/core/ops/nn_ops.cc12
1 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index de059a3e7e..a3609372a9 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -819,7 +819,7 @@ REGISTER_OP("DepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
@@ -945,7 +945,7 @@ REGISTER_OP("Conv3D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
@@ -977,7 +977,7 @@ REGISTER_OP("Conv3DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropInputV2")
@@ -1003,7 +1003,7 @@ REGISTER_OP("Conv3DBackpropFilter")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropFilterV2")
@@ -1032,7 +1032,7 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
@@ -1069,7 +1069,7 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {half, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())