// path: root/tensorflow/core/framework/kernel_def_builder.h
// blob: 0c14d1e006dea4996003296e5206f6d51d6b73e5
#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_

#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/port.h"

namespace tensorflow {

// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
class KernelDefBuilder {
 public:
  // Starts with just the name field set.
  // Caller MUST call Build() and take ownership of the result.
  explicit KernelDefBuilder(const char* op_name);

  ~KernelDefBuilder() {
    DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
  }

  // Required: specify the type of device this kernel supports.
  // Returns *this.
  KernelDefBuilder& Device(const char* device_type);
  //  KernelDefBuilder& Device(DeviceType device_type);

  // Specify that this kernel supports a limited set of values for a
  // particular type or list(type) attr (a further restriction than
  // what the Op allows).
  // Returns *this.
  KernelDefBuilder& TypeConstraint(const char* attr_name,
                                   gtl::ArraySlice<DataType> allowed);

  // Like TypeConstraint but supports just a single type.
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);

  // Like TypeConstraint, but (a) gets the type from a template parameter
  // and (b) only supports a constraint to a single type.
  template <class T>
  KernelDefBuilder& TypeConstraint(const char* attr_name);
  // TODO(josh11b): Support other types of attr constraints as needed.

  // Specify that this kernel requires/provides an input/output arg
  // in host memory (instead of the default, device memory).
  // Returns *this.
  KernelDefBuilder& HostMemory(const char* arg_name);

  // Specify that this kernel requires a particular value for the
  // "_kernel" attr.  May only be specified once.  Returns *this.
  KernelDefBuilder& Label(const char* label);

  // Returns a pointer to a KernelDef with fields set based on the
  // above calls to this instance.
  // Caller takes ownership of the result.
  const KernelDef* Build() {
    KernelDef* r = kernel_def_;
    kernel_def_ = nullptr;
    return r;
  }

 private:
  KernelDef* kernel_def_;

  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};

// IMPLEMENTATION

// Maps the C++ type T to its DataType via DataTypeToEnum<T>, then forwards
// to the single-DataType overload of TypeConstraint() and returns that
// overload's result.
template <class T>
inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
    const char* attr_name) {
  const auto dtype = DataTypeToEnum<T>::v();
  return TypeConstraint(attr_name, dtype);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_