From a5b5049e36a7d9338ce2d77fb5a3066af8a4951c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 6 Jun 2016 11:40:22 -0800 Subject: Change register_types.h to support individual TF_CALL_float, TF_CALL_half, etc. macros, and change the call-many macros to use those. This will let us change some other kernels to use TF_CALL*, for cases where no existing subset of types is what the kernel wants. This can help size on IOS. Change: 124166009 --- tensorflow/core/framework/register_types.h | 173 +++++++++++++++++------------ 1 file changed, 105 insertions(+), 68 deletions(-) diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index da8045f4d5..a843f9e654 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -20,8 +20,14 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" -// Macros to apply another macro to lists of supported types. If you change -// the lists of types, please also update the list in types.cc. +// Two sets of macros: +// - TF_CALL_float, TF_CALL_double, etc. which call the given macro with +// the type name as the only parameter - except on platforms for which +// the type should not be included. +// - Macros to apply another macro to lists of supported types. These also call +// into TF_CALL_float, TF_CALL_double, etc. so they filter by target platform +// as well. +// If you change the lists of types, please also update the list in types.cc. // // See example uses of these macros in core/ops. // @@ -44,92 +50,123 @@ limitations under the License. */ #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) -// Call "m" for all number types that support the comparison operations "<" and -// ">". -#define TF_CALL_INTEGRAL_TYPES(m) \ - m(::tensorflow::int64) m(::tensorflow::int32) m(::tensorflow::uint16) \ - m(::tensorflow::int16) m(::tensorflow::uint8) m(::tensorflow::int8) -#define TF_CALL_REAL_NUMBER_TYPES(m) \ - TF_CALL_INTEGRAL_TYPES(m) \ - m(Eigen::half) m(float) m(double) - -#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ - m(Eigen::half) m(float) m(double) m(::tensorflow::int64) \ - m(::tensorflow::uint16) m(::tensorflow::int16) m(::tensorflow::uint8) \ - m(::tensorflow::int8) - -// Call "m" for all number types, including complex64 and complex128. -#define TF_CALL_NUMBER_TYPES(m) \ - TF_CALL_REAL_NUMBER_TYPES(m) \ - m(::tensorflow::complex64) m(::tensorflow::complex128) - -#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ - TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ - m(::tensorflow::complex64) m(::tensorflow::complex128) - -#define TF_CALL_POD_TYPES(m) \ - TF_CALL_NUMBER_TYPES(m) \ - m(bool) - -// Call "m" on all types. -#define TF_CALL_ALL_TYPES(m) \ - TF_CALL_POD_TYPES(m) \ - m(::tensorflow::string) - -// Call "m" on all types supported on GPU. -#define TF_CALL_GPU_NUMBER_TYPES(m) m(Eigen::half) m(float) m(double) - -#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) \ - m(float); \ - m(double) - -// Call "m" on all quantized types. -#define TF_CALL_QUANTIZED_TYPES(m) \ - m(::tensorflow::qint8) m(::tensorflow::quint8) m(::tensorflow::qint32) +// All types are supported, so all macros are invoked. +// +// Note: macros are defined in same order as types in types.proto, for +// readability. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) m(double) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint8(m) m(::tensorflow::uint8) +#define TF_CALL_int16(m) m(::tensorflow::int16) + +#define TF_CALL_int8(m) m(::tensorflow::int8) +#define TF_CALL_string(m) m(string) +#define TF_CALL_complex64(m) m(::tensorflow::complex64) +#define TF_CALL_int64(m) m(::tensorflow::int64) +#define TF_CALL_bool(m) m(bool) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) m(::tensorflow::bfloat16) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) m(::tensorflow::uint16) +#define TF_CALL_complex128(m) m(::tensorflow::complex128) +#define TF_CALL_half(m) m(Eigen::half) #elif defined(__ANDROID_TYPES_FULL__) -#define TF_CALL_REAL_NUMBER_TYPES(m) \ - m(Eigen::half) m(float) m(::tensorflow::int32) m(::tensorflow::int64) +// Only half, float, int32, int64, and quantized types are supported. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) + +#define TF_CALL_int8(m) +#define TF_CALL_string(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) m(::tensorflow::int64) +#define TF_CALL_bool(m) + +#define TF_CALL_qint8(m) m(::tensorflow::qint8) +#define TF_CALL_quint8(m) m(::tensorflow::quint8) +#define TF_CALL_qint32(m) m(::tensorflow::qint32) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) m(::tensorflow::qint16) + +#define TF_CALL_quint16(m) m(::tensorflow::quint16) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) m(Eigen::half) -#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) -#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ - m(Eigen::half) m(float) m(::tensorflow::int64) +// Only float and int32 are supported. +#define TF_CALL_float(m) m(float) +#define TF_CALL_double(m) +#define TF_CALL_int32(m) m(::tensorflow::int32) +#define TF_CALL_uint8(m) +#define TF_CALL_int16(m) -#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) +#define TF_CALL_int8(m) +#define TF_CALL_string(m) +#define TF_CALL_complex64(m) +#define TF_CALL_int64(m) +#define TF_CALL_bool(m) -#define TF_CALL_POD_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +#define TF_CALL_qint8(m) +#define TF_CALL_quint8(m) +#define TF_CALL_qint32(m) +#define TF_CALL_bfloat16(m) +#define TF_CALL_qint16(m) -#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +#define TF_CALL_quint16(m) +#define TF_CALL_uint16(m) +#define TF_CALL_complex128(m) +#define TF_CALL_half(m) -// Maybe we could put an empty macro here for Android? -#define TF_CALL_GPU_NUMBER_TYPES(m) m(float) m(Eigen::half) +#endif // defined(IS_MOBILE_PLATFORM) - end of TF_CALL_type defines -// Call "m" on all quantized types. -#define TF_CALL_QUANTIZED_TYPES(m) \ - m(::tensorflow::qint8) m(::tensorflow::quint8) m(::tensorflow::qint32) +// Defines for sets of types. -#else // defined(IS_MOBILE_PLATFORM) && !defined(__ANDROID_TYPES_FULL__) +#define TF_CALL_INTEGRAL_TYPES(m) \ + TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \ + TF_CALL_uint8(m) TF_CALL_int8(m) -#define TF_CALL_REAL_NUMBER_TYPES(m) m(float) m(::tensorflow::int32) +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) -#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) TF_CALL_int64(m) \ + TF_CALL_uint16(m) TF_CALL_int16(m) TF_CALL_uint8(m) TF_CALL_int8(m) -#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float) +// Call "m" for all number types, including complex64 and complex128. +#define TF_CALL_NUMBER_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES(m) TF_CALL_complex64(m) TF_CALL_complex128(m) -#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_complex64(m) TF_CALL_complex128(m) -#define TF_CALL_POD_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +#define TF_CALL_POD_TYPES(m) TF_CALL_NUMBER_TYPES(m) TF_CALL_bool(m) -#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) +// Call "m" on all types. +#define TF_CALL_ALL_TYPES(m) TF_CALL_POD_TYPES(m) TF_CALL_string(m) -// Maybe we could put an empty macro here for Android? -#define TF_CALL_GPU_NUMBER_TYPES(m) m(float) +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_NUMBER_TYPES(m) \ + TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m) -#define TF_CALL_QUANTIZED_TYPES(m) +#define TF_CALL_GPU_NUMBER_TYPES_NO_HALF(m) TF_CALL_float(m) TF_CALL_double(m) -#endif // defined(IS_MOBILE_PLATFORM) +// Call "m" on all quantized types. +// TODO(cwhipkey): include TF_CALL_qint16(m) TF_CALL_quint16(m) +#define TF_CALL_QUANTIZED_TYPES(m) \ + TF_CALL_qint8(m) TF_CALL_quint8(m) TF_CALL_qint32(m) #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ -- cgit v1.2.3