diff options
author | 2017-08-17 11:33:14 -0700 | |
---|---|---|
committer | 2017-08-17 11:36:50 -0700 | |
commit | 935ff49201edd7a6297b313fb9545d1299b9a28d (patch) | |
tree | 36486015014d33efa99d7fd0875eb1545bd518cb /tensorflow/core/common_runtime/function.cc | |
parent | d94dca2174f0c05dfa03796c3ae31d345813d025 (diff) |
Automated g4 rollback of changelist 165521057
PiperOrigin-RevId: 165604864
Diffstat (limited to 'tensorflow/core/common_runtime/function.cc')
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 83 |
1 files changed, 30 insertions, 53 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 829fba780f..6b529d8f13 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -139,14 +139,15 @@ static Node* AddRet(Graph* g, Endpoint input, int index) { return ret; } +static const FunctionLibraryRuntime::Handle kInvalidHandle = -1; + class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { public: FunctionLibraryRuntimeImpl(const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent); + CustomKernelCreator custom_kernel_creator); ~FunctionLibraryRuntimeImpl() override; @@ -183,13 +184,17 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { const FunctionLibraryDefinition* const lib_def_; GraphOptimizer optimizer_; const CustomKernelCreator custom_kernel_creator_; - const string device_name_; std::function<Status(const string&, const OpDef**)> get_func_sig_; std::function<Status(const NodeDef&, OpKernel**)> create_kernel_; mutable mutex mu_; + // Maps function instantiation to a handle. The key is a + // canonicalized representation of the function name and + // instantiation attrs. The handle is an index into the items_. + std::unordered_map<string, Handle> table_ GUARDED_BY(mu_); + // func_graphs_ never shrinks or reorders its members. std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_); @@ -203,15 +208,12 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { }; std::vector<Item*> items_; - ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned. - Status FunctionDefToBody(const FunctionDef& fdef, AttrSlice attrs, FunctionBody** fbody); Status CreateItem(Handle handle, Item** item); Status GetOrCreateItem(Handle handle, Item** item); Status InstantiateSymbolicGradient(const NameAttrList& func, FunctionBody** g_body); - bool IsLocalTarget(const AttrSlice& attrs); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -220,19 +222,14 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( const DeviceMgr* dmgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent) + CustomKernelCreator custom_kernel_creator) : device_mgr_(dmgr), device_(device), env_(env), graph_def_version_(graph_def_version), lib_def_(lib_def), optimizer_(optimizer_options), - custom_kernel_creator_(std::move(custom_kernel_creator)), - device_name_(device_ == nullptr - ? ProcessFunctionLibraryRuntime::kDefaultFLRDevice - : device_->name()), - parent_(parent) { + custom_kernel_creator_(std::move(custom_kernel_creator)) { get_func_sig_ = [this](const string& op, const OpDef** sig) { return lib_def_->LookUpOpDef(op, sig); }; @@ -297,17 +294,10 @@ class CallOp : public AsyncOpKernel { }; const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, h); - if (local_handle == kInvalidLocalHandle) { - LOG(ERROR) << "Could not find Handle: " << h - << " on device: " << device_name_; - return nullptr; - } - mutex_lock l(mu_); - CHECK_LE(0, local_handle); - CHECK_LT(local_handle, func_graphs_.size()); - return func_graphs_[local_handle]; + CHECK_LE(static_cast<Handle>(0), h); + CHECK_LT(h, func_graphs_.size()); + return func_graphs_[h]; } Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, @@ -403,24 +393,16 @@ Status FunctionLibraryRuntimeImpl::InstantiateSymbolicGradient( return Status::OK(); } -bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) { - if (device_ == nullptr) return true; - string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs); - if (target.empty()) return true; - return target == device_->name(); -} - Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) { - if (!IsLocalTarget(attrs)) { - return parent_->Instantiate(function_name, attrs, handle); - } - const string key = Canonicalize(function_name, attrs); - *handle = parent_->GetHandle(key); - if (*handle != kInvalidHandle) { - return Status::OK(); + { + mutex_lock l(mu_); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); + if (*handle != kInvalidHandle) { + return Status::OK(); + } } Status s; @@ -449,11 +431,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, { mutex_lock l(mu_); - *handle = parent_->GetHandle(key); + *handle = gtl::FindWithDefault(table_, key, kInvalidHandle); if (*handle != kInvalidHandle) { delete fbody; } else { - *handle = parent_->AddHandle(key, device_name_, func_graphs_.size()); + *handle = func_graphs_.size(); + table_.insert({key, *handle}); func_graphs_.push_back(fbody); items_.resize(func_graphs_.size()); } @@ -511,14 +494,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { } Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { - LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); { mutex_lock l(mu_); - if (local_handle >= items_.size()) { + if (handle >= items_.size()) { return errors::NotFound("Function handle ", handle, " is not valid. Likely an internal error."); } - *item = items_[local_handle]; + *item = items_[handle]; if (*item != nullptr) { (*item)->Ref(); return Status::OK(); @@ -530,9 +512,9 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { { mutex_lock l(mu_); - if (items_[local_handle] == nullptr) { + if (items_[handle] == nullptr) { // Install *item in items_. - items_[local_handle] = *item; + items_[handle] = *item; (*item)->Ref(); } } @@ -546,9 +528,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { return done(errors::Cancelled("")); } - if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { - return parent_->Run(opts, handle, args, rets, done); - } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); @@ -637,21 +616,19 @@ std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, - CustomKernelCreator custom_kernel_creator, - ProcessFunctionLibraryRuntime* parent) { + CustomKernelCreator custom_kernel_creator) { return std::unique_ptr<FunctionLibraryRuntime>(new FunctionLibraryRuntimeImpl( device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - std::move(custom_kernel_creator), parent)); + std::move(custom_kernel_creator))); } std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime( const DeviceMgr* device_mgr, Env* env, Device* device, int graph_def_version, const FunctionLibraryDefinition* lib_def, - const OptimizerOptions& optimizer_options, - ProcessFunctionLibraryRuntime* parent) { + const OptimizerOptions& optimizer_options) { return NewFunctionLibraryRuntime(device_mgr, env, device, graph_def_version, lib_def, optimizer_options, - GetCustomCreatorSingleton()->Get(), parent); + GetCustomCreatorSingleton()->Get()); } bool RemoveDeadNodes(Graph* g) { |