blob: 0c14d1e006dea4996003296e5206f6d51d6b73e5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/port.h"
namespace tensorflow {
// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
class KernelDefBuilder {
public:
// Starts with just the name field set.
// Caller MUST call Build() and take ownership of the result.
explicit KernelDefBuilder(const char* op_name);
~KernelDefBuilder() {
DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
}
// Required: specify the type of device this kernel supports.
// Returns *this.
KernelDefBuilder& Device(const char* device_type);
// KernelDefBuilder& Device(DeviceType device_type);
// Specify that this kernel supports a limited set of values for a
// particular type or list(type) attr (a further restriction than
// what the Op allows).
// Returns *this.
KernelDefBuilder& TypeConstraint(const char* attr_name,
gtl::ArraySlice<DataType> allowed);
// Like TypeConstraint but supports just a single type.
KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
// Like TypeConstraint, but (a) gets the type from a template parameter
// and (b) only supports a constraint to a single type.
template <class T>
KernelDefBuilder& TypeConstraint(const char* attr_name);
// TODO(josh11b): Support other types of attr constraints as needed.
// Specify that this kernel requires/provides an input/output arg
// in host memory (instead of the default, device memory).
// Returns *this.
KernelDefBuilder& HostMemory(const char* arg_name);
// Specify that this kernel requires a particular value for the
// "_kernel" attr. May only be specified once. Returns *this.
KernelDefBuilder& Label(const char* label);
// Returns a pointer to a KernelDef with fields set based on the
// above calls to this instance.
// Caller takes ownership of the result.
const KernelDef* Build() {
KernelDef* r = kernel_def_;
kernel_def_ = nullptr;
return r;
}
private:
KernelDef* kernel_def_;
TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};
// IMPLEMENTATION
template <class T>
inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
    const char* attr_name) {
  // Map the C++ type T onto its DataType enum value, then forward to the
  // single-type overload, which records the constraint and returns *this.
  const DataType single_allowed_type = DataTypeToEnum<T>::v();
  return TypeConstraint(attr_name, single_allowed_type);
}
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
|