#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by CUDA code and must remain compilable by nvcc.
#include "tensorflow/core/platform/port.h"
// Macros to apply another macro to lists of supported types. If you change
// the lists of types, please also update the list in types.cc.
//
// See example uses of these macros in core/ops.
//
//
// Each of these TF_CALL_XXX_TYPES(m) macros invokes the macro "m" multiple
// times by passing each invocation a data type supported by TensorFlow.
//
// The different variations pass different subsets of the types.
// TF_CALL_ALL_TYPES(m) applies "m" to all types supported by TensorFlow.
// The set of types depends on the compilation platform.
//
// This can be used to register a different template instantiation of
// an OpKernel for different signatures, e.g.:
/*
  #define REGISTER_PARTITION(type) \
    REGISTER_TF_OP_KERNEL("partition", DEVICE_CPU, #type ", int32", \
                          PartitionOp<type>);
  TF_CALL_ALL_TYPES(REGISTER_PARTITION)
  #undef REGISTER_PARTITION
*/
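// Another illustrative sketch (ConcatHelper and INSTANTIATE_CONCAT are
// hypothetical names, not defined in this header): the same macros can drive
// explicit template instantiation of a class template for every supported
// type, e.g.:
/*
  #define INSTANTIATE_CONCAT(type) template class ConcatHelper<type>
  TF_CALL_REAL_NUMBER_TYPES(INSTANTIATE_CONCAT);
  #undef INSTANTIATE_CONCAT
*/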
#ifndef __ANDROID__
// Call "m" for all number types that support the comparison operations "<" and
// ">".
#define TF_CALL_REAL_NUMBER_TYPES(m) \
  m(float); \
  m(double); \
  m(int64); \
  m(int32); \
  m(uint8); \
  m(int16); \
  m(int8)
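// For illustration, using the hypothetical REGISTER_PARTITION macro from the
// header comment above, TF_CALL_REAL_NUMBER_TYPES(REGISTER_PARTITION) expands
// to:
//   REGISTER_PARTITION(float); REGISTER_PARTITION(double);
//   REGISTER_PARTITION(int64); REGISTER_PARTITION(int32);
//   REGISTER_PARTITION(uint8); REGISTER_PARTITION(int16);
//   REGISTER_PARTITION(int8)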
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) \
  m(float); \
  m(double); \
  m(int64); \
  m(uint8); \
  m(int16); \
  m(int8)
// Call "m" for all number types, including complex64.
#define TF_CALL_NUMBER_TYPES(m) \
  TF_CALL_REAL_NUMBER_TYPES(m); \
  m(complex64)
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) \
  TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m); \
  m(complex64)
// Call "m" on all types.
#define TF_CALL_ALL_TYPES(m) \
  TF_CALL_NUMBER_TYPES(m); \
  m(bool); \
  m(string)
// Call "m" on all types supported on GPU.
#define TF_CALL_GPU_NUMBER_TYPES(m) \
m(float); \
m(double)
#else // __ANDROID__
#define TF_CALL_REAL_NUMBER_TYPES(m) \
  m(float); \
  m(int32)
#define TF_CALL_NUMBER_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m) m(float)
#define TF_CALL_NUMBER_TYPES_NO_INT32(m) TF_CALL_REAL_NUMBER_TYPES_NO_INT32(m)
#define TF_CALL_ALL_TYPES(m) TF_CALL_REAL_NUMBER_TYPES(m)
// Maybe we could put an empty macro here for Android?
#define TF_CALL_GPU_NUMBER_TYPES(m) m(float)
#endif // __ANDROID__
#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_