diff options
author | Rohan Jain <rohanj@google.com> | 2017-12-01 14:18:39 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-01 14:24:51 -0800 |
commit | 5b420f6bb29cd0c3796dd2d7cb2aa4bf0adf620f (patch) | |
tree | df3e44acdcd90ef7b524675d261cc08a698bae78 /tensorflow/core | |
parent | ed9163acfd510c26c49201ec9e360e20a2625ca8 (diff) |
Adds a ReleaseHandle method to the FunctionLibraryRuntime interface that allows for releasing the state associated with the handle.
Also simplifies the state owned by the FunctionLibraryRuntimeImpl. Instead of having a vector of ref-counted Item objects and a separate vector of function bodies, we merge them into one object that holds the entire instantiated state for the function.
PiperOrigin-RevId: 177639560
Diffstat (limited to 'tensorflow/core')
6 files changed, 138 insertions, 37 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 23d0f331c5..4c87c922c2 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -153,6 +153,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) override; + Status ReleaseHandle(Handle handle) override; + const FunctionBody* GetFunctionBody(Handle handle) override; Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override; @@ -190,18 +192,21 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { mutable mutex mu_; - // func_graphs_ never shrinks or reorders its members. - std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_); + int next_handle_ GUARDED_BY(mu_); // The instantiated and transformed function is encoded as a Graph // object, and an executor is created for the graph. struct Item : public core::RefCounted { const Graph* graph = nullptr; // Owned by exec. + FunctionBody* func_graph = nullptr; Executor* exec = nullptr; - ~Item() override { delete this->exec; } + ~Item() override { + delete this->func_graph; + delete this->exec; + } }; - std::vector<Item*> items_; + std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_); ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned. @@ -236,6 +241,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( device_name_(device_ == nullptr ? 
ProcessFunctionLibraryRuntime::kDefaultFLRDevice : device_->name()), + next_handle_(0), parent_(parent) { get_func_sig_ = [this](const string& op, const OpDef** sig) { return lib_def_->LookUpOpDef(op, sig); @@ -246,9 +252,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl( } FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() { - for (FunctionBody* p : func_graphs_) delete p; - for (Item* item : items_) - if (item) item->Unref(); + for (auto item : items_) { + if (item.second) item.second->Unref(); + } } // An asynchronous op kernel which executes an instantiated function @@ -309,9 +315,8 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) { } mutex_lock l(mu_); - CHECK_LE(0, local_handle); - CHECK_LT(local_handle, func_graphs_.size()); - return func_graphs_[local_handle]; + CHECK_EQ(1, items_.count(local_handle)); + return items_[local_handle]->func_graph; } Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, @@ -478,14 +483,32 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name, if (*handle != kInvalidHandle) { delete fbody; } else { - *handle = parent_->AddHandle(key, device_name_, func_graphs_.size()); - func_graphs_.push_back(fbody); - items_.resize(func_graphs_.size()); + *handle = parent_->AddHandle(key, device_name_, next_handle_); + Item* item = new Item; + item->func_graph = fbody; + items_.insert({next_handle_, item}); + next_handle_++; } } return Status::OK(); } +Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { + if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { + return parent_->ReleaseHandle(handle); + } + + LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); + mutex_lock l(mu_); + CHECK_EQ(1, items_.count(h)); + Item* item = items_[h]; + if (item->Unref()) { + items_.erase(h); + TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle)); + } + return Status::OK(); +} + void DumpGraph(StringPiece label, const Graph* g) { // 
TODO(zhifengc): Change Graph to record #nodes. VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges " @@ -529,7 +552,6 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { Executor* exec; TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec)); - *item = new Item; (*item)->graph = graph; (*item)->exec = exec; return Status::OK(); @@ -539,13 +561,12 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle); { mutex_lock l(mu_); - if (local_handle >= items_.size()) { + if (items_.count(local_handle) == 0) { return errors::NotFound("Function handle ", handle, " is not valid. Likely an internal error."); } *item = items_[local_handle]; - if (*item != nullptr) { - (*item)->Ref(); + if ((*item)->exec != nullptr) { return Status::OK(); } } @@ -556,9 +577,8 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { { mutex_lock l(mu_); if (items_[local_handle] == nullptr) { - // Install *item in items_. - items_[local_handle] = *item; - (*item)->Ref(); + // Insert *item in items_. + items_.insert({local_handle, *item}); } } return Status::OK(); @@ -617,7 +637,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, *exec_args, [item, frame, rets, done, source_device, target_device, target_incarnation, rendezvous, device_context, remote_args, exec_args](const Status& status) { - item->Unref(); Status s = status; if (s.ok()) { s = frame->ConsumeRetvals(rets); @@ -701,7 +720,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, *exec_args, // Done callback. 
[item, frame, rets, done, exec_args](const Status& status) { - item->Unref(); Status s = status; if (s.ok()) { s = frame->ConsumeRetvals(rets); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index d183bf7c97..575af566d5 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -207,7 +207,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { return status; } FunctionLibraryRuntime::Options opts; - return Run(flr, handle, opts, args, std::move(rets)); + status = Run(flr, handle, opts, args, rets); + if (!status.ok()) return status; + + // Release the handle and try running again. It should not succeed. + status = flr->ReleaseHandle(handle); + if (!status.ok()) return status; + + Status status2 = Run(flr, handle, opts, args, std::move(rets)); + EXPECT_TRUE(errors::IsInvalidArgument(status2)); + EXPECT_TRUE( + StringPiece(status2.error_message()).contains("remote execution.")); + + return status; } std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr, @@ -498,7 +510,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) { Scope s = Scope::NewRootScope(); auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto x4_x2_scale = ops::Const<float>( - s.WithOpName("x4/x2/scale/_12__cf__2") + s.WithOpName("x4/x2/scale/_12__cf__3") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale); @@ -694,13 +706,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) { auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0); auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1); auto scale = ops::Const( - s.WithOpName("scale/_5__cf__6") + s.WithOpName("scale/_5__cf__7") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 2.0f); auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale); auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), 
x); auto const0 = ops::Const( - s.WithOpName("Func/_1/sy/_6__cf__7") + s.WithOpName("Func/_1/sy/_6__cf__8") .WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"), 0, {0}); auto func1_rx = ops::internal::BroadcastGradientArgs( diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 142ff2339b..53a14121d4 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -30,7 +30,10 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const FunctionLibraryDefinition* lib_def, const OptimizerOptions& optimizer_options, DistributedFunctionLibraryRuntime* parent) - : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) { + : device_mgr_(device_mgr), + lib_def_(lib_def), + next_handle_(0), + parent_(parent) { if (device_mgr == nullptr) { flr_map_[nullptr] = NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version, @@ -50,7 +53,10 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( const OptimizerOptions& optimizer_options, CustomKernelCreator custom_kernel_creator, DistributedFunctionLibraryRuntime* parent) - : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) { + : device_mgr_(device_mgr), + lib_def_(lib_def), + next_handle_(0), + parent_(parent) { if (device_mgr == nullptr) { flr_map_[nullptr] = NewFunctionLibraryRuntime( nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options, @@ -185,30 +191,38 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle( FunctionLibraryRuntime::Handle h = gtl::FindWithDefault(table_, function_key, kInvalidHandle); if (h != kInvalidHandle) { - return h; + if (function_data_.count(h) != 0) return h; } - h = function_data_.size(); - function_data_.emplace_back(device_name, local_handle); + h = next_handle_; + function_data_.insert({h, FunctionData(device_name, 
local_handle)}); table_[function_key] = h; + next_handle_++; return h; } FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle( const string& function_key) const { mutex_lock l(mu_); - return gtl::FindWithDefault(table_, function_key, kInvalidHandle); + FunctionLibraryRuntime::Handle h = + gtl::FindWithDefault(table_, function_key, kInvalidHandle); + if (h != kInvalidHandle) { + if (function_data_.count(h) == 0) return kInvalidHandle; + } + return h; } bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { - return GetHandleOnDevice(device_name, handle) != -1; + return GetHandleOnDevice(device_name, handle) != kInvalidHandle; } FunctionLibraryRuntime::LocalHandle ProcessFunctionLibraryRuntime::GetHandleOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); - CHECK_LE(handle, function_data_.size()); + if (function_data_.count(handle) == 0) { + return kInvalidLocalHandle; + } const FunctionData& function_data = function_data_[handle]; if (function_data.target_device != device_name) { return kInvalidLocalHandle; @@ -219,7 +233,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( string ProcessFunctionLibraryRuntime::GetDeviceName( FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); - CHECK_LE(handle, function_data_.size()); + CHECK_EQ(1, function_data_.count(handle)); const FunctionData& function_data = function_data_[handle]; return function_data.target_device; } @@ -245,6 +259,29 @@ Status ProcessFunctionLibraryRuntime::Instantiate( return Status::OK(); } +Status ProcessFunctionLibraryRuntime::RemoveHandle( + FunctionLibraryRuntime::Handle handle) { + mutex_lock l(mu_); + function_data_.erase(handle); + return Status::OK(); +} + +Status ProcessFunctionLibraryRuntime::ReleaseHandle( + FunctionLibraryRuntime::Handle handle) { + FunctionLibraryRuntime* flr = nullptr; + string target_device; + { + mutex_lock l(mu_); + CHECK_EQ(1, 
function_data_.count(handle)); + target_device = function_data_[handle].target_device; + } + flr = GetFLR(target_device); + if (flr != nullptr) { + return flr->ReleaseHandle(handle); + } + return errors::InvalidArgument("Handle not found: ", handle); +} + void ProcessFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args, @@ -261,7 +298,10 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::LocalHandle local_handle; { mutex_lock l(mu_); - CHECK_LE(handle, function_data_.size()); + if (function_data_.count(handle) == 0) { + done(errors::NotFound("Handle: ", handle, " not found.")); + return; + } target_device = function_data_[handle].target_device; local_handle = function_data_[handle].local_handle; } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index a267bc3601..3aa7b87286 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -123,6 +123,12 @@ class ProcessFunctionLibraryRuntime { Status Instantiate(const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle); + // Delegates to the local FLR that owns state corresponding to `handle` and + // tells it to release it. If the `handle` isn't needed at all, the local FLR + // might call RemoveHandle on this to get rid of the state owned by the Proc + // FLR. + Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); + // Runs the function with given `handle`. Function could have been // instantiated on any device. More details in framework/function.h void Run(const FunctionLibraryRuntime::Options& opts, @@ -140,6 +146,9 @@ class ProcessFunctionLibraryRuntime { // of the device where the function is registered. 
string GetDeviceName(FunctionLibraryRuntime::Handle handle); + // Removes handle from the state owned by this object. + Status RemoveHandle(FunctionLibraryRuntime::Handle handle); + friend class FunctionLibraryRuntimeImpl; mutable mutex mu_; @@ -151,6 +160,7 @@ class ProcessFunctionLibraryRuntime { FunctionData(const string& target_device, FunctionLibraryRuntime::LocalHandle local_handle) : target_device(target_device), local_handle(local_handle) {} + FunctionData() : FunctionData("", -1) {} }; const DeviceMgr* const device_mgr_; @@ -158,8 +168,10 @@ class ProcessFunctionLibraryRuntime { // Holds all the function invocations here. std::unordered_map<string, FunctionLibraryRuntime::Handle> table_ GUARDED_BY(mu_); - std::vector<FunctionData> function_data_ GUARDED_BY(mu_); + std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData> + function_data_ GUARDED_BY(mu_); std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_; + int next_handle_ GUARDED_BY(mu_); DistributedFunctionLibraryRuntime* const parent_; }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 6bc8f980c7..270e46dfe9 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -82,6 +82,22 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { EXPECT_GE(call_count, 1); // Test runner is used. + // Release the handle and then try running the function. It shouldn't + // succeed. 
+ status = proc_flr_->ReleaseHandle(handle); + if (!status.ok()) { + return status; + } + Notification done2; + proc_flr_->Run(opts, handle, args, &out, + [&status, &done2](const Status& s) { + status = s; + done2.Notify(); + }); + done2.WaitForNotification(); + EXPECT_TRUE(errors::IsNotFound(status)); + EXPECT_TRUE(StringPiece(status.error_message()).contains("not found.")); + return Status::OK(); } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 305b140a44..d3d6358362 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -408,6 +408,9 @@ class FunctionLibraryRuntime { virtual Status Instantiate(const string& function_name, AttrSlice attrs, Handle* handle) = 0; + // Releases state associated with the handle. + virtual Status ReleaseHandle(Handle handle) = 0; + // Returns the function body for the instantiated function given its // handle 'h'. Returns nullptr if "h" is not found. // |