1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
#ifndef TENSORFLOW_FRAMEWORK_OP_H_
#define TENSORFLOW_FRAMEWORK_OP_H_
#include <functional>
#include <unordered_map>
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/status.h"
namespace tensorflow {
// Read-only lookup interface mapping an op type name to its OpDef.
//
// Code that only needs to resolve op type names should accept an
// OpRegistryInterface rather than a concrete registry. Functions that take a
// (const) OpRegistryInterface* are allowed to call LookUp() from multiple
// threads at once.
class OpRegistryInterface {
 public:
  virtual ~OpRegistryInterface();

  // Resolves `op_type_name` to its registered OpDef.
  //
  // Returns:
  //   - the registered OpDef on success; the caller must NOT delete it.
  //   - nullptr when nothing is registered under that name, in which case
  //     *status is set to describe the failure.
  virtual const OpDef* LookUp(const string& op_type_name,
                              Status* status) const = 0;
};
// The standard implementation of OpRegistryInterface, along with a
// global singleton used for registering OpDefs via the REGISTER
// macros below. Thread-safe.
//
// Example registration:
// OpRegistry::Global()->Register([]()->OpDef{
// OpDef def;
// // Populate def here.
// return def;
// });
class OpRegistry : public OpRegistryInterface {
public:
OpRegistry();
~OpRegistry() override {}
// Calls func() and registers the returned OpDef. Since Register()
// is normally called during program initialization (before main()),
// we defer calling func() until the first call to LookUp() or
// Export() (if one of those has already been called, func() is
// called immediately).
void Register(std::function<OpDef(void)> func);
const OpDef* LookUp(const string& op_type_name,
Status* status) const override;
// Fills *ops with all registered OpDefss (except those with names
// starting with '_' if include_internal == false).
void Export(bool include_internal, OpList* ops) const;
// Returns ASCII-format OpList for all registered OpDefs (except
// those with names starting with '_' if include_internal == false).
string DebugString(bool include_internal) const;
// A singleton available at startup.
static OpRegistry* Global();
private:
// Ensures that all the functions in deferred_ get called, their OpDef's
// registered, and returns with deferred_ empty. Returns true the first
// time it is called.
bool CallDeferred() const EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Add 'def' to the registry. On failure, or if there is already an
// OpDef with that name registered, returns a non-okay status.
Status RegisterAlreadyLocked(const OpDef& def) const
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutable mutex mu_;
// Functions in deferred_ may only be called with mu_ held.
mutable std::vector<std::function<OpDef(void)>> deferred_ GUARDED_BY(mu_);
mutable std::unordered_map<string, OpDef*> registry_ GUARDED_BY(mu_);
mutable bool initialized_ GUARDED_BY(mu_);
};
// REGISTER_OP: define an OpDef (the semantics of an Op and how it should be
// created) and register it with the OpRegistry::Global() registry.
//
// Usage:
//
//   REGISTER_OP("my_op_name")
//       .Attr("<name>:<type>")
//       .Attr("<name>:<type>=<default>")
//       .Input("<name>:<type-expr>")
//       .Input("<name>:Ref(<type-expr>)")
//       .Output("<name>:<type-expr>")
//       .Doc(R"(
//   <1-line summary>
//   <rest of the description (potentially many lines)>
//   <name-of-attr-input-or-output>: <description of name>
//   <name-of-attr-input-or-output>: <description of name;
//     if long, indent the description on subsequent lines>
//   )");
//
// Note: .Doc() must be the last call in the chain.
// See the OpDefBuilder class in op_def_builder.h for the full syntax.
namespace register_op {

// Entry point used by the REGISTER_OP macro to call
// OpRegistry::Global()->Register(...) for the op called `name`. Returns the
// OpDefBuilder that the macro's chained .Attr()/.Input()/.Output()/.Doc()
// calls configure.
OpDefBuilder& RegisterOp(StringPiece name);

}  // namespace register_op
// REGISTER_OP(name) expands to a static OpDefBuilder& named register_op<N>,
// where <N> is a fresh __COUNTER__ value, initialized from
// ::tensorflow::register_op::RegisterOp(name). Nothing reads the variable,
// so it is marked TF_ATTRIBUTE_UNUSED.
//
// The REGISTER_OP_UNIQ_HELPER layer is required: arguments used with `##`
// are not macro-expanded, so without the extra level the pasted name would
// literally be `register_op__COUNTER__` and every use would collide.
// (Macros expand lazily, so the order of these definitions does not matter.)
#define REGISTER_OP_UNIQ(ctr, name)                                         \
  static ::tensorflow::OpDefBuilder& register_op##ctr TF_ATTRIBUTE_UNUSED = \
      ::tensorflow::register_op::RegisterOp(name)
#define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
#define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_H_
|