aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2016-06-16 19:59:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-16 21:03:08 -0700
commit84d123c6ffece5a31426c9b9eb48f04696045458 (patch)
tree5656499690eadaec0f05874e6d44ed54c7dee8a4 /tensorflow/core/framework/op.cc
parent5a16e1b0d47f2ba17601c6886b8f529a8e5899f4 (diff)
Made the OpRegistry::Register function take a factory instead of a
unique_ptr<OpRegistrationData>, so that the OpDef building can be deferred until an call to Lookup. Change: 125132276
Diffstat (limited to 'tensorflow/core/framework/op.cc')
-rw-r--r--tensorflow/core/framework/op.cc67
1 files changed, 47 insertions, 20 deletions
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc
index 2ea735a790..41bb07581f 100644
--- a/tensorflow/core/framework/op.cc
+++ b/tensorflow/core/framework/op.cc
@@ -47,17 +47,12 @@ OpRegistry::~OpRegistry() {
for (const auto& e : registry_) delete e.second;
}
-void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) {
- OpRegistrationData* raw_ptr = op_reg_data.get();
-
+void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) {
mutex_lock lock(mu_);
if (initialized_) {
- TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data)));
+ TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
} else {
- deferred_.push_back(std::move(op_reg_data));
- }
- if (watcher_) {
- watcher_(raw_ptr->op_def);
+ deferred_.push_back(op_data_factory);
}
}
@@ -133,6 +128,21 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const {
}
}
+void OpRegistry::DeferRegistrations() {
+ mutex_lock lock(mu_);
+ initialized_ = false;
+}
+
+void OpRegistry::ClearDeferredRegistrations() {
+ mutex_lock lock(mu_);
+ deferred_.clear();
+}
+
+void OpRegistry::ProcessRegistrations() const {
+ mutex_lock lock(mu_);
+ CallDeferred();
+}
+
string OpRegistry::DebugString(bool include_internal) const {
OpList op_list;
Export(include_internal, &op_list);
@@ -147,23 +157,34 @@ bool OpRegistry::CallDeferred() const {
if (initialized_) return false;
initialized_ = true;
for (int i = 0; i < deferred_.size(); ++i) {
- TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i])));
+ TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
}
deferred_.clear();
return true;
}
Status OpRegistry::RegisterAlreadyLocked(
- std::unique_ptr<OpRegistrationData> op_reg_data) const {
- TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def));
-
- if (gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
- op_reg_data.get())) {
- op_reg_data.release(); // Ownership transferred to op_registry
- return Status::OK();
+ OpRegistrationDataFactory op_data_factory) const {
+ std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
+ Status s = op_data_factory(op_reg_data.get());
+ if (s.ok()) {
+ s = ValidateOpDef(op_reg_data->op_def);
+ if (s.ok() &&
+ !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
+ op_reg_data.get())) {
+ s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
+ }
+ }
+ Status watcher_status = s;
+ if (watcher_) {
+ watcher_status = watcher_(s, op_reg_data->op_def);
+ }
+ if (s.ok()) {
+ op_reg_data.release();
} else {
- return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
+ op_reg_data.reset();
}
+ return watcher_status;
}
// static
@@ -202,9 +223,15 @@ Status OpListOpRegistry::LookUp(const string& op_type_name,
namespace register_op {
OpDefBuilderReceiver::OpDefBuilderReceiver(
const OpDefBuilderWrapper<true>& wrapper) {
- std::unique_ptr<OpRegistrationData> data(new OpRegistrationData);
- wrapper.builder().Finalize(data.get());
- OpRegistry::Global()->Register(std::move(data));
+ OpRegistry::Global()->Register(
+ [wrapper](OpRegistrationData* op_reg_data) -> Status {
+ wrapper.builder().Finalize(op_reg_data);
+ // TODO(keveman): Add this check back again in a separate CL.
+ // if (!s.ok()) {
+ // return s;
+ // }
+ return Status::OK();
+ });
}
} // namespace register_op