aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-05-09 09:42:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 10:52:05 -0700
commit16986a1c9ed64c2312ededf733f20a137b521819 (patch)
tree1b3f281941851a8c51cfe77a82df4fc574dfff88
parentbcec296af809947145a6ebfa1e46b1cafe21ec06 (diff)
[Functions] Fix unbounded memory growth in FunctionLibraryRuntime.
A recent change modified the behavior of `FunctionLibraryRuntimeImpl::ReleaseHandle()` so that it no longer freed the memory associated with an instantiated function. Since we rely on instantiating and releasing a potentially large number of instances of the same function in tf.data to isolate the (e.g. random number generator) state in each instance, this change meant that the memory consumption could grow without bound in a simple program like: ```python ds = tf.data.Dataset.from_tensors(0).repeat(None) # The function `lambda y: y + 1` would be instantiated for each element in the input. ds = ds.flat_map(lambda x: tf.data.Dataset.from_tensors(x).map( lambda y: y + tf.random_uniform([], minval=0, maxval=10, dtype=tf.int32))) iterator = ds.make_one_shot_iterator() next_elem = iterator.get_next() with tf.Session() as sess: while True: sess.run(next_elem) ``` PiperOrigin-RevId: 195983977
-rw-r--r--tensorflow/core/common_runtime/function.cc66
-rw-r--r--tensorflow/core/common_runtime/function_test.cc27
-rw-r--r--tensorflow/core/common_runtime/function_threadpool_test.cc14
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc17
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h12
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc10
6 files changed, 94 insertions, 52 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index bf05f6f1d9..d05564e9c4 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -208,19 +208,19 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
// The instantiated and transformed function is encoded as a Graph
// object, and an executor is created for the graph.
- struct Item : public core::RefCounted {
- bool invalidated = false;
+ struct Item {
+ uint64 instantiation_counter = 0;
const Graph* graph = nullptr; // Owned by exec.
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
- ~Item() override {
+ ~Item() {
delete this->func_graph;
delete this->exec;
}
};
- std::unordered_map<Handle, Item*> items_ GUARDED_BY(mu_);
+ std::unordered_map<Handle, std::unique_ptr<Item>> items_ GUARDED_BY(mu_);
ProcessFunctionLibraryRuntime* parent_ = nullptr; // not owned.
@@ -284,9 +284,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
}
}
-FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
- for (auto p : items_) p.second->Unref();
-}
+FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {}
// An asynchronous op kernel which executes an instantiated function
// defined in a library.
@@ -490,30 +488,24 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
- Handle found_handle = kInvalidHandle;
{
mutex_lock l(mu_);
- found_handle = parent_->GetHandle(key);
- if (found_handle != kInvalidHandle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
FunctionLibraryRuntime::LocalHandle handle_on_device =
- parent_->GetHandleOnDevice(device_name_, found_handle);
+ parent_->GetHandleOnDevice(device_name_, *handle);
if (handle_on_device == kInvalidLocalHandle) {
return errors::Internal("LocalHandle not found for handle ", *handle,
".");
}
- auto iter = items_.find(handle_on_device);
- if (iter == items_.end()) {
+ auto item_handle = items_.find(handle_on_device);
+ if (item_handle == items_.end()) {
return errors::Internal("LocalHandle ", handle_on_device,
- " for handle ", found_handle,
+ " for handle ", *handle,
" not found in items.");
}
- Item* item = iter->second;
- if (!item->invalidated) {
- *handle = found_handle;
- return Status::OK();
- }
- // *item is invalidated. Fall through and instantiate the given
- // function_name/attrs/option again.
+ ++item_handle->second->instantiation_counter;
+ return Status::OK();
}
}
@@ -545,16 +537,18 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
{
mutex_lock l(mu_);
- Handle found_handle_again = parent_->GetHandle(key);
- if (found_handle_again != found_handle) {
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
delete fbody;
- *handle = found_handle_again;
+ ++items_[parent_->GetHandleOnDevice(device_name_, *handle)]
+ ->instantiation_counter;
} else {
*handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item;
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
- items_.insert({next_handle_, item});
+ item->instantiation_counter = 1;
+ items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
}
@@ -565,12 +559,17 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->ReleaseHandle(handle);
}
+
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
- Item* item = items_[h];
- item->invalidated = true; // Reinstantiate later.
+ std::unique_ptr<Item>& item = items_[h];
+ --item->instantiation_counter;
+ if (item->instantiation_counter == 0) {
+ items_.erase(h);
+ TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
+ }
return Status::OK();
}
@@ -680,7 +679,7 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle];
+ *item = items_[local_handle].get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
@@ -731,7 +730,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
- item->Ref();
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", src_incarnation, args.size(),
device_context, {}, rendezvous, remote_args,
@@ -743,7 +741,6 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
s = frame->SetArgs(*remote_args);
}
if (!s.ok()) {
- item->Unref();
delete frame;
delete remote_args;
delete exec_args;
@@ -751,10 +748,9 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args, [item, frame, rets, done, source_device, target_device,
+ *exec_args, [frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -840,13 +836,11 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
- [item, frame, rets, done, exec_args](const Status& status) {
- core::ScopedUnref unref(item);
+ [frame, rets, done, exec_args](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -906,7 +900,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
exec_args->runner = *run_opts.runner;
exec_args->call_frame = frame;
- item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
@@ -915,7 +908,6 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
[item, frame, exec_args](DoneCallback done,
// Start unbound arguments.
const Status& status) {
- core::ScopedUnref unref(item);
delete exec_args;
done(status);
},
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 373fc64007..61b2f0e60f 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -231,8 +231,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- TF_RETURN_IF_ERROR(Run(flr, handle, opts, args, rets, add_runner));
- return flr->ReleaseHandle(handle);
+ status = Run(flr, handle, opts, args, rets, add_runner);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
@@ -293,8 +304,16 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = retvals[i];
}
- // Release the handle.
- return flr->ReleaseHandle(handle);
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 98dac38a8c..2d09e83d01 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -144,7 +144,19 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- return Run(flr, handle, opts, args, std::move(rets), add_runner);
+ status = Run(flr, handle, opts, args, rets, add_runner);
+ if (!status.ok()) return status;
+
+ // Release the handle and try running again. It should not succeed.
+ status = flr->ReleaseHandle(handle);
+ if (!status.ok()) return status;
+
+ Status status2 = Run(flr, handle, opts, args, std::move(rets));
+ EXPECT_TRUE(errors::IsInvalidArgument(status2));
+ EXPECT_TRUE(
+ str_util::StrContains(status2.error_message(), "remote execution."));
+
+ return status;
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 668ce87749..729312a310 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -183,8 +184,8 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::LocalHandle local_handle) {
mutex_lock l(mu_);
auto h = next_handle_;
- FunctionData* fd = new FunctionData(device_name, local_handle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ device_name, local_handle, function_key);
table_[function_key] = h;
next_handle_++;
return h;
@@ -247,8 +248,8 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
gtl::FindWithDefault(table_, function_key, kInvalidHandle);
if (h == kInvalidHandle || function_data_.count(h) == 0) {
h = next_handle_;
- FunctionData* fd = new FunctionData(options.target, kInvalidHandle);
- function_data_[h] = std::unique_ptr<FunctionData>(fd);
+ function_data_[h] = MakeUnique<FunctionData>(
+ options.target, kInvalidHandle, function_key);
table_[function_key] = h;
next_handle_++;
}
@@ -263,6 +264,14 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
return Status::OK();
}
+Status ProcessFunctionLibraryRuntime::RemoveHandle(
+ FunctionLibraryRuntime::Handle handle) {
+ mutex_lock l(mu_);
+ table_.erase(function_data_[handle]->function_key());
+ function_data_.erase(handle);
+ return Status::OK();
+}
+
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::Handle handle) {
FunctionLibraryRuntime* flr = nullptr;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 05e5770899..69381dd34d 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -134,6 +134,9 @@ class ProcessFunctionLibraryRuntime {
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
+ // Removes handle from the state owned by this object.
+ Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
+
Status Clone(Env* env, int graph_def_version,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
@@ -147,10 +150,14 @@ class ProcessFunctionLibraryRuntime {
class FunctionData {
public:
FunctionData(const string& target_device,
- FunctionLibraryRuntime::LocalHandle local_handle)
- : target_device_(target_device), local_handle_(local_handle) {}
+ FunctionLibraryRuntime::LocalHandle local_handle,
+ const string& function_key)
+ : target_device_(target_device),
+ local_handle_(local_handle),
+ function_key_(function_key) {}
string target_device() { return target_device_; }
+ const string& function_key() { return function_key_; }
FunctionLibraryRuntime::LocalHandle local_handle() {
mutex_lock l(mu_);
@@ -169,6 +176,7 @@ class ProcessFunctionLibraryRuntime {
const string target_device_;
FunctionLibraryRuntime::LocalHandle local_handle_ GUARDED_BY(mu_);
+ const string function_key_;
bool init_started_ GUARDED_BY(mu_) = false;
Status init_result_ GUARDED_BY(mu_);
Notification init_done_;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index cc10e77ad2..4fbf2abc67 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -119,13 +119,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
EXPECT_GE(call_count, 1); // Test runner is used.
- // Release the handle and then try running the function. It
- // should still succeed.
+ // Release the handle and then try running the function. It shouldn't
+ // succeed.
status = proc_flr_->ReleaseHandle(handle);
if (!status.ok()) {
return status;
}
-
Notification done2;
proc_flr_->Run(opts, handle, args, &out,
[&status, &done2](const Status& s) {
@@ -133,7 +132,10 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
done2.Notify();
});
done2.WaitForNotification();
- return status;
+ EXPECT_TRUE(errors::IsNotFound(status));
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
+
+ return Status::OK();
}
std::vector<Device*> devices_;