aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-03-11 15:38:16 -0700
committerGravatar Derek Murray <mrry@google.com>2018-03-13 07:58:18 -0700
commitfa5c66ba74a505e4a4b8472332918798bb17bb39 (patch)
tree82d04a1ab4e07f1d0c484c0482351e55a417bcbc
parent9af25bb2a76b0e5607acecaa93ae421352a70748 (diff)
Fixes a race condition in function instantiation.
Previously, if the same function was being concurrently instantiated and released: 1. Thread one could begin to instantiate the function, determine that it already existed in the runtime, then be preempted. 2. Thread two could release the handle on the function, causing it to be freed and removed from the `FunctionLibraryRuntime::items_` map. 3. Thread one could then incorrectly assume that the function still existed, and fail to find it in the `FunctionLibraryRuntime::items_` map, causing a segfault when it attempted to increment the refcount on an uninitialized object. PiperOrigin-RevId: 188661500
-rw-r--r--tensorflow/core/common_runtime/function.cc24
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py8
2 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 3e937ceb64..7174a876f6 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -479,11 +479,26 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
InstantiateOptions options_copy(options);
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
- *handle = parent_->GetHandle(key);
- if (*handle != kInvalidHandle) {
+
+ {
mutex_lock l(mu_);
- items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
- return Status::OK();
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
+ FunctionLibraryRuntime::LocalHandle handle_on_device =
+ parent_->GetHandleOnDevice(device_name_, *handle);
+ if (handle_on_device == kInvalidLocalHandle) {
+ return errors::Internal("LocalHandle not found for handle ", *handle,
+ ".");
+ }
+ auto item_handle = items_.find(handle_on_device);
+ if (item_handle == items_.end()) {
+ return errors::Internal("LocalHandle ", handle_on_device,
+ " for handle ", *handle,
+ " not found in items.");
+ }
+ item_handle->second->Ref();
+ return Status::OK();
+ }
}
Status s;
@@ -536,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
}
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
+ CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
Item* item = items_[h];
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 2c71723167..4f2216f0a3 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -176,6 +176,14 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testParallelFilters(self):
+ dataset = dataset_ops.Dataset.range(10).filter(
+ lambda x: math_ops.equal(x % 2, 0))
+ iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
+ next_elements = [iterator.get_next() for iterator in iterators]
+ with self.test_session() as sess:
+ self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
+
class FilterDatasetBenchmark(test.Benchmark):