diff options
Diffstat (limited to 'tensorflow/core/framework/register_types.h')
-rw-r--r-- | tensorflow/core/framework/register_types.h | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h new file mode 100644 index 0000000000..18473aea2e --- /dev/null +++ b/tensorflow/core/framework/register_types.h @@ -0,0 +1,90 @@ +#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ +// This file is used by cuda code and must remain compilable by nvcc. + +#include "tensorflow/core/platform/port.h" + +// Macros to apply another macro to lists of supported types. If you change +// the lists of types, please also update the list in types.cc. +// +// See example uses of these macros in core/ops. +// +// +// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple +// times by passing each invocation a data type supported by TensorFlow. +// +// The different variations pass different subsets of the types. +// TF_CALL_ALL_TYPES(m) applied "m" to all types supported by TensorFlow. +// The set of types depends on the compilation platform. +//. +// This can be used to register a different template instantiation of +// an OpKernel for different signatures, e.g.: +/* + #define REGISTER_PARTITION(type) \ + REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \ + PartitionOp<type>); + TF_CALL_ALL_TYPES(REGISTER_PARTITION) + #undef REGISTER_PARTITION +*/ + +#ifndef __ANDROID__ + +// Call "m" for all number types that support the comparison operations "<" and +// ">". +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + m(float); \ + m(double); \ + m(int64); \ + m(int32); \ + m(uint8); \ + m(int16); \ + m(int8) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \ + m(float); \ + m(double); \ + m(int64); \ + m(uint8); \ + m(int16); \ + m(int8) + +// Call "m" for all number types, including complex64. +#define TF_CALL_NUMBER_TYPES(m) \ + TF_CALL_REAL_NUMBER_TYPES(m); \ + m(complex64) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \ + TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \ + m(complex64) + +// Call "m" on all types. +#define TF_CALL_ALL_TYPES(m) \ + TF_CALL_NUMBER_TYPES(m); \ + m(bool); \ + m(string) + +// Call "m" on all types supported on GPU. +#define TF_CALL_GPU_NUMBER_TYPES(m) \ + m(float); \ + m(double) + +#else // __ANDROID__ + +#define TF_CALL_REAL_NUMBER_TYPES(m) \ + m(float); \ + m(int32) + +#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) + +#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float) + +#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) + +#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m) + +// Maybe we could put an empty macro here for Android? +#define TF_CALL_GPU_NUMBER_TYPES(m) m(float) + +#endif // __ANDROID__ + +#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_ |