aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/kernel_def_builder.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/kernel_def_builder.h')
-rw-r--r--tensorflow/core/framework/kernel_def_builder.h77
1 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
new file mode 100644
index 0000000000..0c14d1e006
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -0,0 +1,77 @@
+#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+
+// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
+class KernelDefBuilder {
+ public:
+ // Starts with just the name field set.
+ // Caller MUST call Build() and take ownership of the result.
+ explicit KernelDefBuilder(const char* op_name);
+
+ ~KernelDefBuilder() {
+ DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
+ }
+
+ // Required: specify the type of device this kernel supports.
+ // Returns *this.
+ KernelDefBuilder& Device(const char* device_type);
+ // KernelDefBuilder& Device(DeviceType device_type);
+
+ // Specify that this kernel supports a limited set of values for a
+ // particular type or list(type) attr (a further restriction than
+ // what the Op allows).
+ // Returns *this.
+ KernelDefBuilder& TypeConstraint(const char* attr_name,
+ gtl::ArraySlice<DataType> allowed);
+
+ // Like TypeConstraint but supports just a single type.
+ KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
+
+ // Like TypeConstraint, but (a) gets the type from a template parameter
+ // and (b) only supports a constraint to a single type.
+ template <class T>
+ KernelDefBuilder& TypeConstraint(const char* attr_name);
+ // TODO(josh11b): Support other types of attr constraints as needed.
+
+ // Specify that this kernel requires/provides an input/output arg
+ // in host memory (instead of the default, device memory).
+ // Returns *this.
+ KernelDefBuilder& HostMemory(const char* arg_name);
+
+ // Specify that this kernel requires a particular value for the
+ // "_kernel" attr. May only be specified once. Returns *this.
+ KernelDefBuilder& Label(const char* label);
+
+ // Returns a pointer to a KernelDef with fields set based on the
+ // above calls to this instance.
+ // Caller takes ownership of the result.
+ const KernelDef* Build() {
+ KernelDef* r = kernel_def_;
+ kernel_def_ = nullptr;
+ return r;
+ }
+
+ private:
+ KernelDef* kernel_def_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
+};
+
+// IMPLEMENTATION
+
+template <class T>
+inline KernelDefBuilder& KernelDefBuilder::TypeConstraint(
+ const char* attr_name) {
+ return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_