author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-06-06 11:40:22 -0800
committer Martin Wicke <wicke@google.com>  2016-06-12 21:27:29 -0700
commit    a5b5049e36a7d9338ce2d77fb5a3066af8a4951c (patch)
tree      e0b1497b3eedba9267c445ca7eaa97eead73bb31
parent    9216497c34c0351843f6cb2abc80c55ea9211c00 (diff)
Change register_types.h to support individual TF_CALL_float, TF_CALL_half, etc.
macros, and change the call-many macros to use those. This will let us change
some other kernels to use TF_CALL_*, for cases where no existing subset of
types is what the kernel wants. This can help binary size on iOS.

Change: 124166009
-rw-r--r-- tensorflow/core/framework/register_types.h | 173
1 file changed, 105 insertions(+), 68 deletions(-)
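As a minimal standalone sketch of the mechanism this patch introduces (the
TF_CALL_DEMO_* and PRINT_SIZE names below are hypothetical, not part of the
patch): each per-type macro either forwards its argument macro to the type or
expands to nothing, so any list macro composed from per-type macros is
filtered per platform at preprocessing time.

    #include <iostream>

    // Full-platform behavior: forward to m(type). A reduced platform would
    // instead define an unsupported type's macro to expand to nothing, e.g.
    //   #define TF_CALL_DEMO_double(m)
    #define TF_CALL_DEMO_float(m) m(float)
    #define TF_CALL_DEMO_double(m) m(double)
    #define TF_CALL_DEMO_int32(m) m(int)  // stand-in for ::tensorflow::int32

    // A set macro built from the per-type macros inherits their platform
    // filtering automatically.
    #define TF_CALL_DEMO_TYPES(m) \
      TF_CALL_DEMO_float(m) TF_CALL_DEMO_double(m) TF_CALL_DEMO_int32(m)

    // Hypothetical per-type action, standing in for kernel registration.
    #define PRINT_SIZE(T) std::cout << #T << ": " << sizeof(T) << " bytes\n";

    int main() {
      TF_CALL_DEMO_TYPES(PRINT_SIZE)  // expands to one statement per type
      return 0;
    }

Emptying TF_CALL_DEMO_double(m) removes the double statement from every set
macro built on top of it, with no other edits; this is how the patch lets
mobile builds drop types.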
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index da8045f4d5..a843f9e654 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -20,8 +20,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/types.h"
-// Macros to apply another macro to lists of supported types. If you change
-// the lists of types, please also update the list in types.cc.
+// Two sets of macros:
+// - TF_CALL_float, TF_CALL_double, etc. which call the given macro with
+// the type name as the only parameter - except on platforms for which
+// the type should not be included.
+// - Macros to apply another macro to lists of supported types. These also call
+// into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform
+// as well.
+// If you change the lists of types, please also update the list in types.cc.
//
// See example uses of these macros in core/ops.
//
@@ -44,92 +50,123 @@ limitations under the License.
*/
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
-// Call "m" for all number types that support the comparison operations "<" and
-// ">".
-#define TF_CALL_INTEGRAL_TYPES(m) \
- m(::tensorflow::int64) m(::tensorflow::int32) m(::tensorflow::uint16) \
- m(::tensorflow::int16) m(::tensorflow::uint8) m(::tensorflow::int8)
-#define TF_CALL_REAL_NUMBER_TYPES(m) \
- TF_CALL_INTEGRAL_TYPES(m) \
- m(Eigen::half) m(float) m(double)
-
-#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
- m(Eigen::half) m(float) m(double) m(::tensorflow::int64) \
- m(::tensorflow::uint16) m(::tensorflow::int16) m(::tensorflow::uint8) \
- m(::tensorflow::int8)
-
-// Call "m" for all number types, including complex64 and complex128.
-#define TF_CALL_NUMBER_TYPES(m) \
- TF_CALL_REAL_NUMBER_TYPES(m) \
- m(::tensorflow::complex64) m(::tensorflow::complex128)
-
-#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
- TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
- m(::tensorflow::complex64) m(::tensorflow::complex128)
-
-#define TF_CALL_POD_TYPES(m) \
- TF_CALL_NUMBER_TYPES(m) \
- m(bool)
-
-// Call "m" on all types.
-#define TF_CALL_ALL_TYPES(m) \
- TF_CALL_POD_TYPES(m) \
- m(::tensorflow::string)
-
-// Call "m" on all types supported on GPU.
-#define TF_CALL_GPU_NUMBER_TYPES(m) m(Eigen::half) m(float) m(double)
-
-#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) \
- m(float); \
- m(double)
-
-// Call "m" on all quantized types.
-#define TF_CALL_QUANTIZED_TYPES(m) \
- m(::tensorflow::qint8) m(::tensorflow::quint8) m(::tensorflow::qint32)
+// All types are supported, so all macros are invoked.
+//
+// Note: macros are defined in the same order as types in types.proto, for
+// readability.
+#define TF_CALL_float(m) m(float)
+#define TF_CALL_double(m) m(double)
+#define TF_CALL_int32(m) m(::tensorflow::int32)
+#define TF_CALL_uint8(m) m(::tensorflow::uint8)
+#define TF_CALL_int16(m) m(::tensorflow::int16)
+
+#define TF_CALL_int8(m) m(::tensorflow::int8)
+#define TF_CALL_string(m) m(string)
+#define TF_CALL_complex64(m) m(::tensorflow::complex64)
+#define TF_CALL_int64(m) m(::tensorflow::int64)
+#define TF_CALL_bool(m) m(bool)
+
+#define TF_CALL_qint8(m) m(::tensorflow::qint8)
+#define TF_CALL_quint8(m) m(::tensorflow::quint8)
+#define TF_CALL_qint32(m) m(::tensorflow::qint32)
+#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16)
+#define TF_CALL_qint16(m) m(::tensorflow::qint16)
+
+#define TF_CALL_quint16(m) m(::tensorflow::quint16)
+#define TF_CALL_uint16(m) m(::tensorflow::uint16)
+#define TF_CALL_complex128(m) m(::tensorflow::complex128)
+#define TF_CALL_half(m) m(Eigen::half)
#elif defined(__ANDROID_TYPES_FULL__)
-#define TF_CALL_REAL_NUMBER_TYPES(m) \
- m(Eigen::half) m(float) m(::tensorflow::int32) m(::tensorflow::int64)
+// Only half, float, int32, int64, and quantized types are supported.
+#define TF_CALL_float(m) m(float)
+#define TF_CALL_double(m)
+#define TF_CALL_int32(m) m(::tensorflow::int32)
+#define TF_CALL_uint8(m)
+#define TF_CALL_int16(m)
+
+#define TF_CALL_int8(m)
+#define TF_CALL_string(m)
+#define TF_CALL_complex64(m)
+#define TF_CALL_int64(m) m(::tensorflow::int64)
+#define TF_CALL_bool(m)
+
+#define TF_CALL_qint8(m) m(::tensorflow::qint8)
+#define TF_CALL_quint8(m) m(::tensorflow::quint8)
+#define TF_CALL_qint32(m) m(::tensorflow::qint32)
+#define TF_CALL_bfloat16(m)
+#define TF_CALL_qint16(m) m(::tensorflow::qint16)
+
+#define TF_CALL_quint16(m) m(::tensorflow::quint16)
+#define TF_CALL_uint16(m)
+#define TF_CALL_complex128(m)
+#define TF_CALL_half(m) m(Eigen::half)
-#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
-#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
- m(Eigen::half) m(float) m(::tensorflow::int64)
+// Only float and int32 are supported.
+#define TF_CALL_float(m) m(float)
+#define TF_CALL_double(m)
+#define TF_CALL_int32(m) m(::tensorflow::int32)
+#define TF_CALL_uint8(m)
+#define TF_CALL_int16(m)
-#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
+#define TF_CALL_int8(m)
+#define TF_CALL_string(m)
+#define TF_CALL_complex64(m)
+#define TF_CALL_int64(m)
+#define TF_CALL_bool(m)
-#define TF_CALL_POD_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+#define TF_CALL_qint8(m)
+#define TF_CALL_quint8(m)
+#define TF_CALL_qint32(m)
+#define TF_CALL_bfloat16(m)
+#define TF_CALL_qint16(m)
-#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+#define TF_CALL_quint16(m)
+#define TF_CALL_uint16(m)
+#define TF_CALL_complex128(m)
+#define TF_CALL_half(m)
-// Maybe we could put an empty macro here for Android?
-#define TF_CALL_GPU_NUMBER_TYPES(m) m(float) m(Eigen::half)
+#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines
-// Call "m" on all quantized types.
-#define TF_CALL_QUANTIZED_TYPES(m) \
- m(::tensorflow::qint8) m(::tensorflow::quint8) m(::tensorflow::qint32)
+// Defines for sets of types.
-#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__)
+#define TF_CALL_INTEGRAL_TYPES(m) \
+ TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
+ TF_CALL_uint8(m) TF_CALL_int8(m)
-#define TF_CALL_REAL_NUMBER_TYPES(m) m(float) m(::tensorflow::int32)
+#define TF_CALL_REAL_NUMBER_TYPES(m) \
+ TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
-#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
+ TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_int64(m) \
+ TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m)
-#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float)
+// Call "m" for all number types, including complex64 and complex128.
+#define TF_CALL_NUMBER_TYPES(m) \
+ TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m)
-#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
+#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
+ TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
+ TF_CALL_complex64(m) TF_CALL_complex128(m)
-#define TF_CALL_POD_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+#define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m)
-#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
+// Call "m" on all types.
+#define TF_CALL_ALL_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m)
-// Maybe we could put an empty macro here for Android?
-#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
+// Call "m" on all types supported on GPU.
+#define TF_CALL_GPU_NUMBER_TYPES(m) \
+ TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
-#define TF_CALL_QUANTIZED_TYPES(m)
+#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m)
-#endif // defined(IS_MOBILE_PLATFORM)
+// Call "m" on all quantized types.
+// TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m)
+#define TF_CALL_QUANTIZED_TYPES(m) \
+ TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m)
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
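Usage-side sketch of the motivation stated in the commit message: a kernel
needing a subset that no existing list macro covers can now chain individual
TF_CALL_* macros directly, and unsupported types drop out per platform.
DemoFunctor and REGISTER_DEMO below are hypothetical stand-ins, and the
TF_CALL_* definitions are inlined (matching the full-platform branch of the
patch) only so the sketch compiles on its own; real kernel files get them
from register_types.h.

    // Inlined stand-ins; real code includes
    // tensorflow/core/framework/register_types.h instead.
    #define TF_CALL_float(m) m(float)
    #define TF_CALL_int64(m) m(long long)  // stand-in for ::tensorflow::int64

    // Hypothetical functor; real kernels typically pair TF_CALL_* with
    // REGISTER_KERNEL_BUILDER or explicit template instantiation like this.
    template <typename T>
    struct DemoFunctor {
      void operator()(const T&) {}
    };

    #define REGISTER_DEMO(T) template struct DemoFunctor<T>;

    TF_CALL_float(REGISTER_DEMO)  // kept on every platform
    TF_CALL_int64(REGISTER_DEMO)  // expands to nothing on basic mobile builds
    #undef REGISTER_DEMO

    int main() { DemoFunctor<float>()(1.0f); }

Because the filtering happens in the per-type macros, kernels composed this
way need no #ifdef blocks of their own to stay small on mobile.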