aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/register_types.h
blob: 18473aea2e8d10d70418b754f3f5dde9d4fb5f91 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by cuda code and must remain compilable by nvcc.

#include "tensorflow/core/platform/port.h"

// Macros to apply another macro to lists of supported types.  If you change
// the lists of types, please also update the list in types.cc.
//
// See example uses of these macros in core/ops.
//
//
// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
// times by passing each invocation a data type supported by TensorFlow.
//
// The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// The set of types depends on the compilation platform.
//
// This can be used to register a different template instantiation of
// an OpKernel for different signatures, e.g.:
/*
   #define REGISTER_PARTITION(type)                                  \
     REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \
                           PartitionOp<type>);
   TF_CALL_ALL_TYPES(REGISTER_PARTITION)
   #undef REGISTER_PARTITION
*/

#ifndef __ANDROID__

// Invokes "m" once for each real (non-complex) numeric type, i.e. every
// number type that supports the ordering operators "<" and ">".
// Expansion order: float, double, int64, int32, uint8, int16, int8.
// No ';' follows the final invocation; any terminator comes from the call site.
#define TF_CALL_REAL_NUMBER_TYPES(m) \
  m(float); m(double);               \
  m(int64); m(int32);                \
  m(uint8); m(int16); m(int8)

// Same real numeric types as TF_CALL_REAL_NUMBER_TYPES, but without int32.
// Expansion order: float, double, int64, uint8, int16, int8.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
  m(float); m(double);                        \
  m(int64); m(uint8);                         \
  m(int16); m(int8)

// Invokes "m" for every number type: first the real types from
// TF_CALL_REAL_NUMBER_TYPES, then complex64.
#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m); m(complex64)

// Like TF_CALL_NUMBER_TYPES but without int32: first the real types from
// TF_CALL_REAL_NUMBER_TYPES_NO_INT32, then complex64.
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); m(complex64)

// Invokes "m" for every type listed for this platform: all number types
// (via TF_CALL_NUMBER_TYPES), then bool, then string.
#define TF_CALL_ALL_TYPES(m) \
  TF_CALL_NUMBER_TYPES(m); m(bool); m(string)

// Invokes "m" for the types supported on GPU: float, then double.
#define TF_CALL_GPU_NUMBER_TYPES(m) m(float); m(double)

#else  // __ANDROID__

// Android build: only float and int32 are passed to "m", in that order.
#define TF_CALL_REAL_NUMBER_TYPES(m) m(float); m(int32)

// Android build: complex64 is not added, so the number types are exactly
// the real number types.
#define TF_CALL_NUMBER_TYPES(m) \
  TF_CALL_REAL_NUMBER_TYPES(m)

// Android build: removing int32 from the Android real types {float, int32}
// leaves only float.
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
  m(float)

// Android build: no complex64 is appended, so this is the same list as
// TF_CALL_REAL_NUMBER_TYPES_NO_INT32.
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)

// Android build: bool and string are not included, so "all types" is just
// the Android real number types.
#define TF_CALL_ALL_TYPES(m) \
  TF_CALL_REAL_NUMBER_TYPES(m)

// Android build: only float is passed to "m".
// NOTE(review): an empty expansion was suggested as an alternative for
// Android; that question is unresolved, so float is kept for now.
#define TF_CALL_GPU_NUMBER_TYPES(m) \
  m(float)

#endif  // __ANDROID__

#endif  // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_