aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Gunhan Gulsoy <gunan@google.com>2018-03-14 22:04:24 -0700
committerGravatar GitHub <noreply@github.com>2018-03-14 22:04:24 -0700
commite6c83df4ac74a34b3a9073bfe1efe206928bbd5f (patch)
tree439a538681c1a4238a6222fbbef3f5faf44cb5b4
parent8198a84e1a584cd3d14acd0bd52e04cf2d66f341 (diff)
parenta6f8b220638484c0b6e54f3a7d445c155f578535 (diff)
Merge branch 'r1.7' into update_readme
-rw-r--r--tensorflow/core/common_runtime/function.cc24
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py8
2 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 3e937ceb64..7174a876f6 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -479,11 +479,26 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
InstantiateOptions options_copy(options);
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
- *handle = parent_->GetHandle(key);
- if (*handle != kInvalidHandle) {
+
+ {
mutex_lock l(mu_);
- items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
- return Status::OK();
+ *handle = parent_->GetHandle(key);
+ if (*handle != kInvalidHandle) {
+ FunctionLibraryRuntime::LocalHandle handle_on_device =
+ parent_->GetHandleOnDevice(device_name_, *handle);
+ if (handle_on_device == kInvalidLocalHandle) {
+ return errors::Internal("LocalHandle not found for handle ", *handle,
+ ".");
+ }
+ auto item_handle = items_.find(handle_on_device);
+ if (item_handle == items_.end()) {
+ return errors::Internal("LocalHandle ", handle_on_device,
+ " for handle ", *handle,
+ " not found in items.");
+ }
+ item_handle->second->Ref();
+ return Status::OK();
+ }
}
Status s;
@@ -536,6 +551,7 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
}
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
+ CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
Item* item = items_[h];
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 2c71723167..4f2216f0a3 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -176,6 +176,14 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testParallelFilters(self):
+ dataset = dataset_ops.Dataset.range(10).filter(
+ lambda x: math_ops.equal(x % 2, 0))
+ iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
+ next_elements = [iterator.get_next() for iterator in iterators]
+ with self.test_session() as sess:
+ self.assertEqual([0 for _ in range(10)], sess.run(next_elements))
+
class FilterDatasetBenchmark(test.Benchmark):