diff options
author | 2018-03-14 22:04:24 -0700 | |
---|---|---|
committer | 2018-03-14 22:04:24 -0700 | |
commit | e6c83df4ac74a34b3a9073bfe1efe206928bbd5f (patch) | |
tree | 439a538681c1a4238a6222fbbef3f5faf44cb5b4 | |
parent | 8198a84e1a584cd3d14acd0bd52e04cf2d66f341 (diff) | |
parent | a6f8b220638484c0b6e54f3a7d445c155f578535 (diff) |
Merge branch 'r1.7' into update_readme
-rw-r--r-- | tensorflow/core/common_runtime/function.cc | 24 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/filter_dataset_op_test.py | 8 |
2 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 3e937ceb64..7174a876f6 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -479,11 +479,26 @@ Status FunctionLibraryRuntimeImpl::Instantiate( InstantiateOptions options_copy(options); options_copy.target = device_name_; const string key = Canonicalize(function_name, attrs, options_copy); - *handle = parent_->GetHandle(key); - if (*handle != kInvalidHandle) { + + { mutex_lock l(mu_); - items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref(); - return Status::OK(); + *handle = parent_->GetHandle(key); + if (*handle != kInvalidHandle) { + FunctionLibraryRuntime::LocalHandle handle_on_device = + parent_->GetHandleOnDevice(device_name_, *handle); + if (handle_on_device == kInvalidLocalHandle) { + return errors::Internal("LocalHandle not found for handle ", *handle, + "."); + } + auto item_handle = items_.find(handle_on_device); + if (item_handle == items_.end()) { + return errors::Internal("LocalHandle ", handle_on_device, + " for handle ", *handle, + " not found in items."); + } + item_handle->second->Ref(); + return Status::OK(); + } } Status s; @@ -536,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) { } LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle); + CHECK_NE(h, kInvalidLocalHandle); mutex_lock l(mu_); CHECK_EQ(1, items_.count(h)); Item* item = items_[h]; diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py index 2c71723167..4f2216f0a3 100644 --- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py @@ -176,6 +176,14 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testParallelFilters(self): + dataset = dataset_ops.Dataset.range(10).filter( + lambda x: math_ops.equal(x % 2, 0)) + iterators = [dataset.make_one_shot_iterator() for _ in range(10)] + next_elements = [iterator.get_next() for iterator in iterators] + with self.test_session() as sess: + self.assertEqual([0 for _ in range(10)], sess.run(next_elements)) + class FilterDatasetBenchmark(test.Benchmark): |