diff options
author | 2016-06-16 19:59:21 -0800 | |
---|---|---|
committer | 2016-06-16 21:03:08 -0700 | |
commit | 84d123c6ffece5a31426c9b9eb48f04696045458 (patch) | |
tree | 5656499690eadaec0f05874e6d44ed54c7dee8a4 /tensorflow/core/framework/op.cc | |
parent | 5a16e1b0d47f2ba17601c6886b8f529a8e5899f4 (diff) |
Made the OpRegistry::Register function take a factory instead of a
unique_ptr<OpRegistrationData>, so that the OpDef building can be
deferred until an call to Lookup.
Change: 125132276
Diffstat (limited to 'tensorflow/core/framework/op.cc')
-rw-r--r-- | tensorflow/core/framework/op.cc | 67 |
1 files changed, 47 insertions, 20 deletions
diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index 2ea735a790..41bb07581f 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -47,17 +47,12 @@ OpRegistry::~OpRegistry() { for (const auto& e : registry_) delete e.second; } -void OpRegistry::Register(std::unique_ptr<OpRegistrationData> op_reg_data) { - OpRegistrationData* raw_ptr = op_reg_data.get(); - +void OpRegistry::Register(OpRegistrationDataFactory op_data_factory) { mutex_lock lock(mu_); if (initialized_) { - TF_QCHECK_OK(RegisterAlreadyLocked(std::move(op_reg_data))); + TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory)); } else { - deferred_.push_back(std::move(op_reg_data)); - } - if (watcher_) { - watcher_(raw_ptr->op_def); + deferred_.push_back(op_data_factory); } } @@ -133,6 +128,21 @@ void OpRegistry::Export(bool include_internal, OpList* ops) const { } } +void OpRegistry::DeferRegistrations() { + mutex_lock lock(mu_); + initialized_ = false; +} + +void OpRegistry::ClearDeferredRegistrations() { + mutex_lock lock(mu_); + deferred_.clear(); +} + +void OpRegistry::ProcessRegistrations() const { + mutex_lock lock(mu_); + CallDeferred(); +} + string OpRegistry::DebugString(bool include_internal) const { OpList op_list; Export(include_internal, &op_list); @@ -147,23 +157,34 @@ bool OpRegistry::CallDeferred() const { if (initialized_) return false; initialized_ = true; for (int i = 0; i < deferred_.size(); ++i) { - TF_QCHECK_OK(RegisterAlreadyLocked(std::move(deferred_[i]))); + TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i])); } deferred_.clear(); return true; } Status OpRegistry::RegisterAlreadyLocked( - std::unique_ptr<OpRegistrationData> op_reg_data) const { - TF_RETURN_IF_ERROR(ValidateOpDef(op_reg_data->op_def)); - - if (gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), - op_reg_data.get())) { - op_reg_data.release(); // Ownership transferred to op_registry - return Status::OK(); + OpRegistrationDataFactory op_data_factory) const { + std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData); + Status s = op_data_factory(op_reg_data.get()); + if (s.ok()) { + s = ValidateOpDef(op_reg_data->op_def); + if (s.ok() && + !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(), + op_reg_data.get())) { + s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); + } + } + Status watcher_status = s; + if (watcher_) { + watcher_status = watcher_(s, op_reg_data->op_def); + } + if (s.ok()) { + op_reg_data.release(); } else { - return errors::AlreadyExists("Op with name ", op_reg_data->op_def.name()); + op_reg_data.reset(); } + return watcher_status; } // static @@ -202,9 +223,15 @@ Status OpListOpRegistry::LookUp(const string& op_type_name, namespace register_op { OpDefBuilderReceiver::OpDefBuilderReceiver( const OpDefBuilderWrapper<true>& wrapper) { - std::unique_ptr<OpRegistrationData> data(new OpRegistrationData); - wrapper.builder().Finalize(data.get()); - OpRegistry::Global()->Register(std::move(data)); + OpRegistry::Global()->Register( + [wrapper](OpRegistrationData* op_reg_data) -> Status { + wrapper.builder().Finalize(op_reg_data); + // TODO(keveman): Add this check back again in a separate CL. + // if (!s.ok()) { + // return s; + // } + return Status::OK(); + }); } } // namespace register_op |