diff options
author | 2018-03-11 15:38:16 -0700 | |
---|---|---|
committer | 2018-03-13 07:58:18 -0700 | |
commit | fa5c66ba74a505e4a4b8472332918798bb17bb39 (patch) | |
tree | 82d04a1ab4e07f1d0c484c0482351e55a417bcbc | |
parent | 9af25bb2a76b0e5607acecaa93ae421352a70748 (diff) |
Fixes a race condition in function instantiation.
Previously, if the same function was being concurrently instantiated
and released, the following interleaving was possible:
1. Thread one could begin to instantiate the function, determine
that it already existed in the runtime, then be preempted.
2. Thread two could release the handle on the function, causing it to
be freed and removed from the `FunctionLibraryRuntimeImpl::items_` map.
3. Thread one could then resume, incorrectly assuming that the function
still existed, and fail to find it in the `FunctionLibraryRuntimeImpl::items_`
map, causing a segfault when it attempted to increment the refcount
on an uninitialized object.
With this change, `Instantiate` looks up the handle and increments the
refcount while holding `mu_`, and returns an `Internal` error instead of
crashing if the handle or its item cannot be found.
PiperOrigin-RevId: 188661500
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 24 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/filter_dataset_op_test.py | 8 |
2 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3e937ceb64..7174a876f6 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -479,11 +479,26 @@ Status FunctionLibraryRuntimeImpl::Instantiate( InstantiateOptions options_copy(options); options_copy.target = device_name_; const string key = Canonicalize(function_name, attrs, options_copy); - *handle = parent_->GetHandle(key); - if (*handle != kInvalidHandle) { + + { mutex_lock l(mu_); - items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref(); - return Status::OK(); + *handle = parent_->GetHandle(key); + if (*handle != kInvalidHandle) { + FunctionLibraryRuntime::LocalHandle handle_on_device = + parent_->GetHandleOnDevice(device_name_, *handle); + if (handle_on_device == kInvalidLocalHandle) { + return errors::Internal("LocalHandle not found for handle ", *handle, + "."); + } + auto item_handle = items_.find(handle_on_device); + if (item_handle == items_.end()) { + return errors::Internal("LocalHandle ", handle_on_device, + " for handle ", *handle, + " not found in items."); + } + item_handle->second->Ref(); + return Status::OK(); + } } Status s; @@ -536,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { } LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); + CHECK_NE(h, kInvalidLocalHandle); mutex_lock l(mu_); CHECK_EQ(1, items_.count(h)); Item* item = items_[h]; diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 2c71723167..4f2216f0a3 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -176,6 +176,14 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testParallelFilters(self): + dataset = 
dataset_ops.Dataset.range(10).filter( + lambda x: math_ops.equal(x % 2, 0)) + iterators = [dataset.make_one_shot_iterator() for _ in range(10)] + next_elements = [iterator.get_next() for iterator in iterators] + with self.test_session() as sess: + self.assertEqual([0 for _ in range(10)], sess.run(next_elements)) + class FilterDatasetBenchmark(test.Benchmark): |