author    Geoffrey Irving <geoffreyi@google.com>    2017-04-28 15:18:08 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-04-28 16:54:56 -0700
commit    be88359aa1cd5e0221427a3626e0f690239b89de (patch)
tree      58879f7db577966a2d4fd6019e33aa92a2991afa
parent    561fb9b89f49fe12e594f093a7073568c74910ef (diff)
Remove or warn about nn_ops accidentally registered to ints
Several nn_ops were accidentally registered for ints. Some of these even have
kernel registrations for ints, which do very strange things. Ops whose int
registrations have no kernels have been restricted to floating point types,
and warnings have been added for those with kernels. Other ops were registered
for more float types than they had kernels for.

Fixes #9317. Fixes #5539.
Change: 154595629
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py |  15
-rw-r--r--  tensorflow/contrib/makefile/tf_op_files.txt            |   2
-rw-r--r--  tensorflow/core/kernels/BUILD                          |  15
-rw-r--r--  tensorflow/core/kernels/softplus_op.cc                 |  11
-rw-r--r--  tensorflow/core/kernels/softsign_op.cc                 |  11
-rw-r--r--  tensorflow/core/kernels/warn_about_ints.cc             |  32
-rw-r--r--  tensorflow/core/kernels/warn_about_ints.h              |  29
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt        | 298
-rw-r--r--  tensorflow/core/ops/nn_ops.cc                          |  48
-rw-r--r--  tensorflow/python/kernel_tests/softplus_op_test.py     |   6
-rw-r--r--  tensorflow/python/kernel_tests/softsign_op_test.py     |   6
11 files changed, 132 insertions(+), 341 deletions(-)
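Before the diff: a minimal Python sketch of the "warn" half of this change, assuming TF 1.x graph-mode APIs as of this commit (illustrative, not part of the patch). Softplus and Softsign keep their accidental int kernels for now, so building either on an int tensor succeeds but fires the new WarnAboutInts logging in the kernel constructor. The "restrict" half is sketched after the nn_ops.cc diff below.

```python
import tensorflow as tf

# Softplus still has int kernels registered, so this graph builds and runs;
# the new WarnAboutInts() call in the kernel constructor logs a warning that
# the op "was registered with integer support accidentally".
warned = tf.nn.softplus(tf.constant(7))  # int32 input

with tf.Session() as sess:
    print(sess.run(warned))  # evaluated with the int32 kernel, warning logged
```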
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index b48ad09e14..60700c5a65 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -247,7 +247,7 @@ class ConvolutionTest(test.TestCase):
   def testCreateConv(self):
     height, width = 7, 9
     with self.test_session():
-      images = np.random.uniform(size=(5, height, width, 4))
+      images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
       output = layers_lib.convolution2d(images, 32, [3, 3])
       self.assertEqual(output.op.name, 'Conv/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
@@ -259,7 +259,7 @@ class ConvolutionTest(test.TestCase):
   def testCreateConvNCHW(self):
     height, width = 7, 9
     with self.test_session():
-      images = np.random.uniform(size=(5, 4, height, width))
+      images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
       output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW')
       self.assertEqual(output.op.name, 'Conv/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, 32, height, width])
@@ -2780,7 +2780,7 @@ class RepeatTests(test.TestCase):
   def testRepeat(self):
     height, width = 3, 3
     with self.test_session():
-      images = np.random.uniform(size=(5, height, width, 3))
+      images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
       output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3])
       self.assertEqual(output.op.name, 'Repeat/convolution_3/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, 3, 3, 32])
@@ -2815,15 +2815,6 @@ class SeparableConv2dTest(test.TestCase):
       self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
 
-  def testCreateConvFloat64(self):
-    height, width = 3, 3
-    with self.test_session():
-      images = random_ops.random_uniform(
-          (5, height, width, 3), seed=1, dtype=dtypes.float64)
-      output = layers_lib.separable_conv2d(images, 32, [3, 3], 2)
-      self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
-      self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
-
   def testCreateDepthwiseConv(self):
     height, width = 3, 3
     with self.test_session():
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 61fc21b24d..b727e2b888 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -224,4 +224,4 @@ tensorflow/core/ops/array_grad.cc
tensorflow/core/kernels/spacetobatch_functor.cc
tensorflow/core/kernels/spacetobatch_op.cc
tensorflow/core/kernels/batchtospace_op.cc
-
+tensorflow/core/kernels/warn_about_ints.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5702c1b1c0..f87e222afa 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -392,6 +392,15 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "warn_about_ints",
+    srcs = ["warn_about_ints.cc"],
+    hdrs = ["warn_about_ints.h"],
+    deps = [
+        "//tensorflow/core:framework",
+    ],
+)
+
 # Private support libraries ---------------------------------------------------
cc_header_only_library(
@@ -2711,13 +2720,13 @@ tf_kernel_library(
 tf_kernel_library(
     name = "softplus_op",
     prefix = "softplus_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + [":warn_about_ints"],
 )
 
 tf_kernel_library(
     name = "softsign_op",
     prefix = "softsign_op",
-    deps = NN_DEPS,
+    deps = NN_DEPS + [":warn_about_ints"],
 )
 
 tf_kernel_library(
@@ -4033,6 +4042,7 @@ filegroup(
         "training_ops.h",
         "transpose_functor.h",
         "transpose_op.h",
+        "warn_about_ints.h",
         "where_op.h",
         "xent_op.h",
     ],
@@ -4169,6 +4179,7 @@ filegroup(
         "training_ops.cc",
         "transpose_functor_cpu.cc",
         "transpose_op.cc",
+        "warn_about_ints.cc",
         "where_op.cc",
         "xent_op.cc",
         ":android_extended_ops_headers",
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
index 5650435781..494a83ed14 100644
--- a/tensorflow/core/kernels/softplus_op.cc
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/warn_about_ints.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
@@ -33,7 +34,10 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename Device, typename T>
 class SoftplusOp : public UnaryElementWiseOp<T, SoftplusOp<Device, T>> {
  public:
-  using UnaryElementWiseOp<T, SoftplusOp<Device, T>>::UnaryElementWiseOp;
+  explicit SoftplusOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {
+    WarnAboutInts(context);
+  }
 
   void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
     functor::Softplus<Device, T> functor;
@@ -46,7 +50,10 @@ template <typename Device, typename T>
 class SoftplusGradOp
     : public BinaryElementWiseOp<T, SoftplusGradOp<Device, T>> {
  public:
-  using BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>::BinaryElementWiseOp;
+  explicit SoftplusGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {
+    WarnAboutInts(context);
+  }
 
   void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                          const Tensor& a, Tensor* output);
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
index 33b9628b32..00ee649b17 100644
--- a/tensorflow/core/kernels/softsign_op.cc
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/warn_about_ints.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
@@ -33,7 +34,10 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename Device, typename T>
 class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
  public:
-  using UnaryElementWiseOp<T, SoftsignOp<Device, T>>::UnaryElementWiseOp;
+  explicit SoftsignOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {
+    WarnAboutInts(context);
+  }
 
   void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
     functor::Softsign<Device, T> functor;
@@ -46,7 +50,10 @@ template <typename Device, typename T>
 class SoftsignGradOp
     : public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
  public:
-  using BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>::BinaryElementWiseOp;
+  explicit SoftsignGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {
+    WarnAboutInts(context);
+  }
 
   void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
                          const Tensor& a, Tensor* output);
diff --git a/tensorflow/core/kernels/warn_about_ints.cc b/tensorflow/core/kernels/warn_about_ints.cc
new file mode 100644
index 0000000000..fd0a889c99
--- /dev/null
+++ b/tensorflow/core/kernels/warn_about_ints.cc
@@ -0,0 +1,32 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/warn_about_ints.h"
+
+namespace tensorflow {
+
+void WarnAboutInts(OpKernelConstruction* context) {
+  DataType dtype;
+  OP_REQUIRES_OK(context, context->GetAttr("T", &dtype));
+  if (DataTypeIsInteger(dtype)) {
+    LOG(WARNING) << "Op " << context->def().name() << " of type "
+                 << context->def().op() << " used with integer dtype "
+                 << DataTypeString(dtype)
+                 << ". This op was registered with integer support "
+                 << "accidentally, and you won't like the result.";
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/warn_about_ints.h b/tensorflow/core/kernels/warn_about_ints.h
new file mode 100644
index 0000000000..20666b230e
--- /dev/null
+++ b/tensorflow/core/kernels/warn_about_ints.h
@@ -0,0 +1,29 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#define TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// Warn if a kernel is being created using ints
+// TODO(irving): Remove in TF 2.0 along with the bad op registrations.
+void WarnAboutInts(OpKernelConstruction* context);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 85309fba78..8e57e811f8 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -1910,12 +1910,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
type: DT_HALF
}
}
@@ -1960,18 +1954,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -2028,18 +2010,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -2087,18 +2057,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -2159,18 +2117,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -2292,12 +2238,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
type: DT_HALF
}
}
@@ -4090,7 +4030,6 @@ op {
list {
type: DT_HALF
type: DT_FLOAT
- type: DT_DOUBLE
}
}
}
@@ -4154,7 +4093,6 @@ op {
list {
type: DT_HALF
type: DT_FLOAT
- type: DT_DOUBLE
}
}
}
@@ -4218,7 +4156,6 @@ op {
list {
type: DT_HALF
type: DT_FLOAT
- type: DT_DOUBLE
}
}
}
@@ -4278,18 +4215,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4331,18 +4256,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4401,18 +4314,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4461,18 +4362,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4518,18 +4407,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4588,18 +4465,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4648,18 +4513,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -4705,18 +4558,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -6646,12 +6487,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
type: DT_HALF
}
}
@@ -6678,12 +6513,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
type: DT_HALF
}
}
@@ -8400,19 +8229,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -8486,19 +8302,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -8547,9 +8350,7 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_FLOAT
- type: DT_DOUBLE
}
}
}
@@ -8605,9 +8406,7 @@ op {
type: "type"
allowed_values {
list {
- type: DT_HALF
type: DT_FLOAT
- type: DT_DOUBLE
}
}
}
@@ -9528,17 +9327,6 @@ op {
list {
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
type: DT_HALF
}
}
@@ -10782,19 +10570,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -10850,19 +10625,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -10913,19 +10675,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -10989,19 +10738,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -11068,19 +10804,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -11093,19 +10816,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT64
- type: DT_INT32
- type: DT_UINT8
- type: DT_UINT16
- type: DT_INT16
- type: DT_INT8
- type: DT_COMPLEX64
- type: DT_COMPLEX128
- type: DT_QINT8
- type: DT_QUINT8
- type: DT_QINT32
- type: DT_HALF
}
}
}
@@ -11169,14 +10879,6 @@ op {
allowed_values {
list {
type: DT_FLOAT
- type: DT_DOUBLE
- type: DT_INT32
- type: DT_INT64
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_UINT16
- type: DT_HALF
}
}
}
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 932113bf2c..3e58669e30 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -89,7 +89,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: realnumbertype")
+ .Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::AvgPoolShape)
.Doc(R"doc(
Performs average pooling on the input.
@@ -117,7 +117,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: realnumbertype")
+ .Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
// NOTE(mrry): We could in principle work out the shape from the
// gradients and the attrs, but if we do not know orig_input_shape
@@ -272,7 +272,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("batch_variance: T")
.Output("reserve_space_1: T")
.Output("reserve_space_2: T")
- .Attr("T: numbertype")
+ .Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
.Attr("is_training: bool = true")
@@ -348,7 +348,7 @@ REGISTER_OP("FusedBatchNormGrad")
.Output("offset_backprop: T")
.Output("reserve_space_3: T")
.Output("reserve_space_4: T")
- .Attr("T: numbertype")
+ .Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
.Attr("is_training: bool = true")
@@ -504,7 +504,7 @@ REGISTER_OP("Conv2D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
@@ -557,7 +557,7 @@ REGISTER_OP("Conv2DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
@@ -599,7 +599,7 @@ REGISTER_OP("Conv2DBackpropFilter")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
@@ -735,7 +735,7 @@ REGISTER_OP("FusedResizeAndPadConv2D")
.Input("paddings: int32")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {float}")
.Attr("resize_align_corners: bool = false")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
@@ -777,7 +777,7 @@ REGISTER_OP("FusedPadConv2D")
.Input("paddings: int32")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {float}")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
@@ -939,7 +939,7 @@ REGISTER_OP("Conv3D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
@@ -971,7 +971,7 @@ REGISTER_OP("Conv3DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropInputV2")
@@ -997,7 +997,7 @@ REGISTER_OP("Conv3DBackpropFilter")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Deprecated(10, "Use Conv3DBackpropFilterV2")
@@ -1026,7 +1026,7 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
@@ -1063,7 +1063,7 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
@@ -1104,7 +1104,7 @@ REGISTER_OP("AvgPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D average pooling on the input.
@@ -1131,7 +1131,7 @@ REGISTER_OP("AvgPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: numbertype")
+ .Attr("T: {float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -1166,7 +1166,7 @@ REGISTER_OP("MaxPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: numbertype")
+ .Attr("T: {float}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D max pooling on the input.
@@ -1190,12 +1190,12 @@ REGISTER_OP("MaxPool3DGrad")
.Input("orig_output: TInput")
.Input("grad: T")
.Output("output: T")
- .Attr("ksize: list(int) >= 5 ")
+ .Attr("ksize: list(int) >= 5")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: numbertype = DT_FLOAT")
- .Attr("TInput: numbertype = DT_FLOAT")
+ .Attr("T: {float} = DT_FLOAT")
+ .Attr("TInput: {float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
})
@@ -1226,7 +1226,7 @@ REGISTER_OP("MaxPool3DGradGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: realnumbertype")
+ .Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
ShapeHandle unused;
@@ -1260,7 +1260,7 @@ data_format: The data format of the input and output data. With the
REGISTER_OP("L2Loss")
.Input("t: T")
.Output("output: T")
- .Attr("T: numbertype")
+ .Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
L2 Loss.
@@ -1748,7 +1748,7 @@ backprops: The gradients:
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise.
@@ -1761,7 +1761,7 @@ REGISTER_OP("EluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the exponential linear (Elu) operation.
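The nn_ops.cc attr tightenings above (realnumbertype/numbertype replaced by explicit float lists) move failures from kernel dispatch to graph construction. A minimal sketch of the expected effect from Python, assuming TF 1.x APIs (error type and text are illustrative, not quoted from the implementation):

```python
import tensorflow as tf

# L2Loss's attr changed from "T: numbertype" to "T: {half, float, double}",
# so floating-point inputs behave exactly as before...
ok = tf.nn.l2_loss(tf.constant([3.0, 4.0]))  # 0.5 * (9 + 16) = 12.5

# ...while an int32 input now fails while the graph is being built, instead
# of dispatching to a kernel that was never registered.
try:
    tf.nn.l2_loss(tf.constant([3, 4]))
except TypeError as e:
    print("rejected at graph construction:", e)

with tf.Session() as sess:
    print(sess.run(ok))  # 12.5
```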
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index f70f60c0f5..7e4f46c46f 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -85,6 +85,12 @@ class SoftplusTest(test.TestCase):
     print("softplus (float) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
+  def testWarnInts(self):
+    # NOTE(irving): Actually I don't know how to intercept the warning, but
+    # let's make sure it runs. I promise I've looked, and there was a warning.
+    with self.test_session():
+      nn_ops.softplus(constant_op.constant(7)).eval()
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index 5fd5253c09..371f86ff15 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -65,6 +65,12 @@ class SoftsignTest(test.TestCase):
     print("softsign (float) gradient err = ", err)
     self.assertLess(err, 1e-4)
 
+  def testWarnInts(self):
+    # NOTE(irving): Actually I don't know how to intercept the warning, but
+    # let's make sure it runs. I promise I've looked, and there was a warning.
+    with self.test_session():
+      nn_ops.softsign(constant_op.constant(7)).eval()
+
 
 if __name__ == "__main__":
   test.main()