about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com> 2017-12-01 14:18:39 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-12-01 14:24:51 -0800
commit: 5b420f6bb29cd0c3796dd2d7cb2aa4bf0adf620f (patch)
tree: df3e44acdcd90ef7b524675d261cc08a698bae78 /tensorflow/core
parent: ed9163acfd510c26c49201ec9e360e20a2625ca8 (diff)
Adds a ReleaseHandle method to the FunctionLibraryRuntime interface that allows for releasing the state associated with the handle.
Also simplifies the state owned by the FunctionLibraryRuntimeImpl. Instead of having a vector of ref-counted Item objects and a separate vector of function bodies, we merge them into one object that holds the entire instantiated state for the function.
PiperOrigin-RevId: 177639560
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/common_runtime/function.cc 62
-rw-r--r-- tensorflow/core/common_runtime/function_test.cc 20
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.cc 60
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime.h 14
-rw-r--r-- tensorflow/core/common_runtime/process_function_library_runtime_test.cc 16
-rw-r--r-- tensorflow/core/framework/function.h 3
6 files changed, 138 insertions, 37 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 23d0f331c5..4c87c922c2 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -153,6 +153,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Status Instantiate(const string& function_name, AttrSlice attrs,
Handle* handle) override;
+ Status ReleaseHandle(Handle handle) override;
+
const FunctionBody* GetFunctionBody(Handle handle) override;
Status CreateKernel(const NodeDef& ndef, OpKernel** kernel) override;
@@ -190,18 +192,21 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
mutable mutex mu_;
- // func_graphs_ never shrinks or reorders its members.
- std::vector<FunctionBody*> func_graphs_ GUARDED_BY(mu_);
+ int next_handle_ GUARDED_BY(mu_);
// The instantiated and transformed function is encoded as a Graph
// object, and an executor is created for the graph.
struct Item : public core::RefCounted {
const Graph* graph = nullptr; // Owned by exec.
+ FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
- ~Item() override { delete this->exec; }
+ ~Item() override {
+ delete this->func_graph;
+ delete this->exec;
+ }
};
- std::vector<Item*> items_;
+ std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
@@ -236,6 +241,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
device_name_(device_ == nullptr
? ProcessFunctionLibraryRuntime::kDefaultFLRDevice
: device_->name()),
+ next_handle_(0),
parent_(parent) {
get_func_sig_ = [this](const string& op, const OpDef** sig) {
return lib_def_->LookUpOpDef(op, sig);
@@ -246,9 +252,9 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
}
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
- for (FunctionBody* p : func_graphs_) delete p;
- for (Item* item : items_)
- if (item) item->Unref();
+ for (auto item : items_) {
+ if (item.second) item.second->Unref();
+ }
}
// An asynchronous op kernel which executes an instantiated function
@@ -309,9 +315,8 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
}
mutex_lock l(mu_);
- CHECK_LE(0, local_handle);
- CHECK_LT(local_handle, func_graphs_.size());
- return func_graphs_[local_handle];
+ CHECK_EQ(1, items_.count(local_handle));
+ return items_[local_handle]->func_graph;
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
@@ -478,14 +483,32 @@ Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
if (*handle != kInvalidHandle) {
delete fbody;
} else {
- *handle = parent_->AddHandle(key, device_name_, func_graphs_.size());
- func_graphs_.push_back(fbody);
- items_.resize(func_graphs_.size());
+ *handle = parent_->AddHandle(key, device_name_, next_handle_);
+ Item* item = new Item;
+ item->func_graph = fbody;
+ items_.insert({next_handle_, item});
+ next_handle_++;
}
}
return Status::OK();
}
+Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
+ if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
+ return parent_->ReleaseHandle(handle);
+ }
+
+ LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
+ mutex_lock l(mu_);
+ CHECK_EQ(1, items_.count(h));
+ Item* item = items_[h];
+ if (item->Unref()) {
+ items_.erase(h);
+ TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
+ }
+ return Status::OK();
+}
+
void DumpGraph(StringPiece label, const Graph* g) {
// TODO(zhifengc): Change Graph to record #nodes.
VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
@@ -529,7 +552,6 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
Executor* exec;
TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec));
- *item = new Item;
(*item)->graph = graph;
(*item)->exec = exec;
return Status::OK();
@@ -539,13 +561,12 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
{
mutex_lock l(mu_);
- if (local_handle >= items_.size()) {
+ if (items_.count(local_handle) == 0) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
*item = items_[local_handle];
- if (*item != nullptr) {
- (*item)->Ref();
+ if ((*item)->exec != nullptr) {
return Status::OK();
}
}
@@ -556,9 +577,8 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
{
mutex_lock l(mu_);
if (items_[local_handle] == nullptr) {
- // Install *item in items_.
- items_[local_handle] = *item;
- (*item)->Ref();
+ // Insert *item in items_.
+ items_.insert({local_handle, *item});
}
}
return Status::OK();
@@ -617,7 +637,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
*exec_args, [item, frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args](const Status& status) {
- item->Unref();
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -701,7 +720,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
*exec_args,
// Done callback.
[item, frame, rets, done, exec_args](const Status& status) {
- item->Unref();
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index d183bf7c97..575af566d5 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -207,7 +207,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- return Run(flr, handle, opts, args, std::move(rets));
+ status = Run(flr, handle, opts, args, rets);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ StringPiece(status2.error_message()).contains("remote execution."));
+
+ return status;
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
@@ -498,7 +510,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__2")
+ s.WithOpName("x4/x2/scale/_12__cf__3")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -694,13 +706,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_5__cf__6")
+ s.WithOpName("scale/_5__cf__7")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_6__cf__7")
+ s.WithOpName("Func/_1/sy/_6__cf__8")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 142ff2339b..53a14121d4 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -30,7 +30,10 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const FunctionLibraryDefinition* lib_def,
const OptimizerOptions& optimizer_options,
DistributedFunctionLibraryRuntime* parent)
- : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) {
+ : device_mgr_(device_mgr),
+ lib_def_(lib_def),
+ next_handle_(0),
+ parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[nullptr] =
NewFunctionLibraryRuntime(nullptr, env, nullptr, graph_def_version,
@@ -50,7 +53,10 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
DistributedFunctionLibraryRuntime* parent)
- : device_mgr_(device_mgr), lib_def_(lib_def), parent_(parent) {
+ : device_mgr_(device_mgr),
+ lib_def_(lib_def),
+ next_handle_(0),
+ parent_(parent) {
if (device_mgr == nullptr) {
flr_map_[nullptr] = NewFunctionLibraryRuntime(
nullptr, env, nullptr, graph_def_version, lib_def, optimizer_options,
@@ -185,30 +191,38 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::Handle h =
gtl::FindWithDefault(table_, function_key, kInvalidHandle);
if (h != kInvalidHandle) {
- return h;
+ if (function_data_.count(h) != 0) return h;
}
- h = function_data_.size();
- function_data_.emplace_back(device_name, local_handle);
+ h = next_handle_;
+ function_data_.insert({h, FunctionData(device_name, local_handle)});
table_[function_key] = h;
+ next_handle_++;
return h;
}
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
mutex_lock l(mu_);
- return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
+ FunctionLibraryRuntime::Handle h =
+ gtl::FindWithDefault(table_, function_key, kInvalidHandle);
+ if (h != kInvalidHandle) {
+ if (function_data_.count(h) == 0) return kInvalidHandle;
+ }
+ return h;
}
bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
- return GetHandleOnDevice(device_name, handle) != -1;
+ return GetHandleOnDevice(device_name, handle) != kInvalidHandle;
}
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
- CHECK_LE(handle, function_data_.size());
+ if (function_data_.count(handle) == 0) {
+ return kInvalidLocalHandle;
+ }
const FunctionData& function_data = function_data_[handle];
if (function_data.target_device != device_name) {
return kInvalidLocalHandle;
@@ -219,7 +233,7 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
- CHECK_LE(handle, function_data_.size());
+ CHECK_EQ(1, function_data_.count(handle));
const FunctionData& function_data = function_data_[handle];
return function_data.target_device;
}
@@ -245,6 +259,29 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
return Status::OK();
}
+Status ProcessFunctionLibraryRuntime::RemoveHandle(
+ FunctionLibraryRuntime::Handle handle) {
+ mutex_lock l(mu_);
+ function_data_.erase(handle);
+ return Status::OK();
+}
+
+Status ProcessFunctionLibraryRuntime::ReleaseHandle(
+ FunctionLibraryRuntime::Handle handle) {
+ FunctionLibraryRuntime* flr = nullptr;
+ string target_device;
+ {
+ mutex_lock l(mu_);
+ CHECK_EQ(1, function_data_.count(handle));
+ target_device = function_data_[handle].target_device;
+ }
+ flr = GetFLR(target_device);
+ if (flr != nullptr) {
+ return flr->ReleaseHandle(handle);
+ }
+ return errors::InvalidArgument("Handle not found: ", handle);
+}
+
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
@@ -261,7 +298,10 @@ void ProcessFunctionLibraryRuntime::Run(
FunctionLibraryRuntime::LocalHandle local_handle;
{
mutex_lock l(mu_);
- CHECK_LE(handle, function_data_.size());
+ if (function_data_.count(handle) == 0) {
+ done(errors::NotFound("Handle: ", handle, " not found."));
+ return;
+ }
target_device = function_data_[handle].target_device;
local_handle = function_data_[handle].local_handle;
}
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index a267bc3601..3aa7b87286 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -123,6 +123,12 @@ class ProcessFunctionLibraryRuntime {
Status Instantiate(const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle);
+ // Delegates to the local FLR that owns state corresponding to `handle` and
+ // tells it to release it. If the `handle` isn't needed at all, the local FLR
+ // might call RemoveHandle on this to get rid of the state owned by the Proc
+ // FLR.
+ Status ReleaseHandle(FunctionLibraryRuntime::Handle handle);
+
// Runs the function with given `handle`. Function could have been
// instantiated on any device. More details in framework/function.h
void Run(const FunctionLibraryRuntime::Options& opts,
@@ -140,6 +146,9 @@ class ProcessFunctionLibraryRuntime {
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
+ // Removes handle from the state owned by this object.
+ Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+
friend class FunctionLibraryRuntimeImpl;
mutable mutex mu_;
@@ -151,6 +160,7 @@ class ProcessFunctionLibraryRuntime {
FunctionData(const string& target_device,
FunctionLibraryRuntime::LocalHandle local_handle)
: target_device(target_device), local_handle(local_handle) {}
+ FunctionData() : FunctionData("", -1) {}
};
const DeviceMgr* const device_mgr_;
@@ -158,8 +168,10 @@ class ProcessFunctionLibraryRuntime {
// Holds all the function invocations here.
std::unordered_map<string, FunctionLibraryRuntime::Handle> table_
GUARDED_BY(mu_);
- std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
+ std::unordered_map<FunctionLibraryRuntime::Handle, FunctionData>
+ function_data_ GUARDED_BY(mu_);
std::unordered_map<Device*, std::unique_ptr<FunctionLibraryRuntime>> flr_map_;
+ int next_handle_ GUARDED_BY(mu_);
DistributedFunctionLibraryRuntime* const parent_;
};
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 6bc8f980c7..270e46dfe9 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -82,6 +82,22 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
EXPECT_GE(call_count, 1); // Test runner is used.
+ // Release the handle and then try running the function. It shouldn't
+ // succeed.
+ status = proc_flr_->ReleaseHandle(handle);
+ if (!status.ok()) {
+ return status;
+ }
+ Notification done2;
+ proc_flr_->Run(opts, handle, args, &out,
+ [&status, &done2](const Status& s) {
+ status = s;
+ done2.Notify();
+ });
+ done2.WaitForNotification();
+ EXPECT_TRUE(errors::IsNotFound(status));
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("not found."));
+
return Status::OK();
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 305b140a44..d3d6358362 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -408,6 +408,9 @@ class FunctionLibraryRuntime {
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
Handle* handle) = 0;
+ // Releases state associated with the handle.
+ virtual Status ReleaseHandle(Handle handle) = 0;
+
// Returns the function body for the instantiated function given its
// handle 'h'. Returns nullptr if "h" is not found.
//