diff options
author | 2018-08-10 17:05:25 -0700 | |
---|---|---|
committer | 2018-08-10 17:09:12 -0700 | |
commit | 12eb80cb9b4b51631a7cdfc9fce476a8b2ea225b (patch) | |
tree | b438ba58a924721a8e84a159bbbee647ddfa9efa /tensorflow/contrib/data | |
parent | 6c08c6c22a7ccd3adad28fb76269122ab0a1fcaa (diff) |
Speeding up MultiDeviceIterator by more efficient locking. We create a background thread that tries to keep a host side buffer for each device full. When a GetNext request comes in, we return from the buffer if available or else we schedule a callback to be called when the background thread eventually fetches an element for it.
PiperOrigin-RevId: 208292329
Diffstat (limited to 'tensorflow/contrib/data')
4 files changed, 271 insertions, 96 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 32f03ca683..13bcd77b4a 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -526,6 +526,15 @@ string SanitizeThreadSuffix(string suffix) { return clean; } +struct HostBufferElement { + Status status; + bool end_of_sequence; + std::vector<Tensor> value; +}; + +using MultiDeviceIteratorCallback = + std::function<void(const HostBufferElement&)>; + class MultiDeviceIterator : public ResourceBase { public: MultiDeviceIterator(const DataTypeVector& output_types, @@ -539,83 +548,45 @@ class MultiDeviceIterator : public ResourceBase { devices_(devices), flib_def_(std::move(flib_def)), pflr_(std::move(pflr)), - lib_(lib) { - buffer_.resize(devices_.size()); - } + lib_(lib) {} string DebugString() override { - return strings::StrCat("MultiDeviceIterator"); + return strings::StrCat("MultiDeviceIterator for ", devices_.size(), + " devices"); } - Status Init(std::unique_ptr<IteratorBase> iterator, int64* incarnation_id) { - mutex_lock l(mu_); + Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size, + int64* incarnation_id) { if (iterator) { TF_RETURN_IF_ERROR( VerifyTypesMatch(output_types_, iterator->output_dtypes())); TF_RETURN_IF_ERROR( VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); } - host_iterator_.reset(iterator.release()); - incarnation_id_++; + + mutex_lock l(mu_); + if (multi_device_buffer_) { + multi_device_buffer_->Reset(); + } + + ++incarnation_id_; *incarnation_id = incarnation_id_; - max_buffer_size_ = 0; - num_elements_ = 0; - buffer_.clear(); - buffer_.resize(devices_.size()); + + multi_device_buffer_.reset( + new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_, + std::move(iterator))); return Status::OK(); } - Status GetNextFromShard(IteratorContext* ctx, int shard_num, - int64 incarnation_id, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) { - // TODO(rohanj): This might potentially strand elements in other shards. - // Opportunity to do smarter locking semantics. - mutex_lock l(mu_); - // Make sure we're in the right incarnation. - if (incarnation_id != incarnation_id_) { - return errors::InvalidArgument( - "Current incarnation: ", incarnation_id_, - "; Supplied incarnation: ", incarnation_id); - } - // Then look it up in the buffer. - if (!buffer_[shard_num].empty()) { - const HostBufferElement& elem = buffer_[shard_num].front(); - *out_tensors = elem.value; - *end_of_sequence = elem.end_of_sequence; - Status s = elem.status; - buffer_[shard_num].pop_front(); - return s; - } - std::shared_ptr<IteratorBase> captured_iterator(host_iterator_); - if (captured_iterator) { - if (lib_ != nullptr) { - ctx->set_lib(lib_); - } - while (true) { - HostBufferElement elem; - elem.status = - captured_iterator->GetNext(ctx, &elem.value, &elem.end_of_sequence); - int buffer_index = num_elements_ % devices_.size(); - num_elements_++; - if (buffer_index == shard_num) { - out_tensors->swap(elem.value); - *end_of_sequence = elem.end_of_sequence; - return elem.status; - } else { - buffer_[buffer_index].push_back(std::move(elem)); - // TODO(rohanj): Put an upper bound to buffer size. - if (buffer_[buffer_index].size() > max_buffer_size_) { - max_buffer_size_ = buffer_[buffer_index].size(); - VLOG(1) << "MultiDeviceIterator: Max buffer size increased to: " - << max_buffer_size_; - } - } - } - } else { - return errors::FailedPrecondition("Iterator not initialized"); + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + if (lib_ != nullptr) { + ctx->set_lib(lib_); } - return Status::OK(); + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); } const DataTypeVector& output_types() const { return output_types_; } @@ -630,25 +601,218 @@ class MultiDeviceIterator : public ResourceBase { } private: - struct HostBufferElement { - Status status; - bool end_of_sequence; - std::vector<Tensor> value; + // A private class that uses a background thread to keep a per device buffer + // full. + class MultiDeviceBuffer { + public: + MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, + std::unique_ptr<IteratorBase> host_iterator) + : buffer_(size), + size_(size), + max_buffer_size_(max_buffer_size), + incarnation_id_(incarnation_id), + host_iterator_(std::move(host_iterator)) {} + + ~MultiDeviceBuffer() { Reset(); } + + void Reset() LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (background_thread_finished_) { + return; + } + + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } + + // Make sure background thread has finished first. + while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } + } + RunPendingCallbacks(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + HostBufferElement elem; + if (incarnation_id_ != incarnation_id) { + elem.status = errors::InvalidArgument("Invalid incarnation id"); + callback(elem); + return; + } + + bool produced_output = false; + { + mutex_lock l(mu_); + if (cancelled_) { + elem.status = errors::Cancelled("Cancelled Multidevice iterator"); + callback(elem); + return; + } + + EnsureBackgroundThreadStarted(ctx); + + if (!buffer_[shard_num].data.empty()) { + produced_output = true; + std::swap(elem, buffer_[shard_num].data.front()); + buffer_[shard_num].data.pop_front(); + // Wake up background thread if it is blocked on this element. + if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) { + buffer_[shard_num].cond_var.notify_all(); + } + } else { + if (background_thread_finished_) { + produced_output = true; + elem.end_of_sequence = true; + } else { + buffer_[shard_num].callbacks.push_back(std::move(callback)); + callback = nullptr; + } + } + } + + if (produced_output) { + callback(elem); + } + } + + private: + void EnsureBackgroundThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!background_thread_) { + background_thread_.reset(ctx->env()->StartThread( + {}, "multi_device_iterator_background_thread", + std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, + this, new IteratorContext(*ctx)))); + } + } + + void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) { + // Run all remaining callbacks. + std::vector<MultiDeviceIteratorCallback> cancellation_callbacks; + std::vector<HostBufferElement> cancellation_elements; + { + mutex_lock l(mu_); + + for (int i = 0; i < size_; ++i) { + while (!buffer_[i].callbacks.empty()) { + if (buffer_[i].data.empty()) { + HostBufferElement elem; + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + cancellation_elements.push_back(std::move(elem)); + } else { + cancellation_elements.push_back( + std::move(buffer_[i].data.front())); + buffer_[i].data.pop_front(); + } + cancellation_callbacks.push_back( + std::move(buffer_[i].callbacks.front())); + buffer_[i].callbacks.pop_front(); + } + } + } + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_elements[i]); + } + } + + void BackgroundThread(IteratorContext* ctx) { + std::unique_ptr<IteratorContext> cleanup(ctx); + int shard_to_fetch = 0; + while (true) { + HostBufferElement elem; + MultiDeviceIteratorCallback callback = nullptr; + bool end_of_iterator = false; + + { + mutex_lock l(mu_); + while (!cancelled_ && + buffer_[shard_to_fetch].data.size() >= max_buffer_size_) { + buffer_[shard_to_fetch].cond_var.wait(l); + } + + if (cancelled_) { + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + return; + } + } + + elem.status = + host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence); + + if (elem.status.ok() && elem.end_of_sequence) { + end_of_iterator = true; + } + + { + mutex_lock l(mu_); + // Try to find a callback, else just push stuff into buffer. + if (!buffer_[shard_to_fetch].callbacks.empty()) { + callback = buffer_[shard_to_fetch].callbacks.front(); + buffer_[shard_to_fetch].callbacks.pop_front(); + } else { + buffer_[shard_to_fetch].data.push_back(std::move(elem)); + elem = HostBufferElement(); + } + } + + if (callback) { + (*ctx->runner())(std::bind(std::move(callback), std::move(elem))); + } + + // Finish off the thread if we reach the end of the iterator. Runs + // pending callbacks. + if (end_of_iterator) { + { + mutex_lock l(mu_); + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + } + RunPendingCallbacks(); + return; + } + shard_to_fetch = (shard_to_fetch + 1) % size_; + } + } + + struct HostBuffer { + condition_variable cond_var; + std::deque<HostBufferElement> data; + std::deque<MultiDeviceIteratorCallback> callbacks; + }; + + mutex mu_; + std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_); + bool background_thread_finished_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + condition_variable shutdown_cond_var_ GUARDED_BY(mu_); + + std::vector<HostBuffer> buffer_; + + const size_t size_; + const int64 max_buffer_size_; + const int64 incarnation_id_; + const std::unique_ptr<IteratorBase> host_iterator_; }; mutex mu_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; const std::vector<string> devices_; - int64 num_elements_ GUARDED_BY(mu_) = 0; - int64 max_buffer_size_ GUARDED_BY(mu_) = 0; - int64 incarnation_id_ GUARDED_BY(mu_) = 0; - std::vector<std::deque<HostBufferElement>> buffer_ GUARDED_BY(mu_); - std::unique_ptr<FunctionLibraryDefinition> flib_def_; - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; - FunctionLibraryRuntime* lib_ = nullptr; // not owned. - std::shared_ptr<IteratorBase> host_iterator_; + const std::unique_ptr<FunctionLibraryDefinition> flib_def_; + const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + FunctionLibraryRuntime* const lib_ = nullptr; // not owned. std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_); + + int64 incarnation_id_ GUARDED_BY(mu_) = 0; + std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_); }; // Just creates a MultiDeviceIterator and returns it. @@ -754,6 +918,10 @@ class MultiDeviceIteratorInitOp : public OpKernel { : OpKernel(ctx) {} void Compute(OpKernelContext* ctx) override { + const Tensor* tensor_max_buffer_size; + OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size)); + int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()(); + DatasetBase* dataset; OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); MultiDeviceIterator* resource; @@ -766,7 +934,8 @@ class MultiDeviceIteratorInitOp : public OpKernel { OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator)); int64 incarnation_id; - OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), &incarnation_id)); + OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, + &incarnation_id)); Tensor tensor_incarnation_id(DT_INT64, TensorShape({})); tensor_incarnation_id.scalar<int64>()() = incarnation_id; OP_REQUIRES_OK(ctx, @@ -804,9 +973,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); thread_pool_->Schedule(std::bind( [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { - std::vector<Tensor> components; - bool end_of_sequence = false; - IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); @@ -817,22 +983,26 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { }; IteratorContext iter_ctx(std::move(params)); - Status s = - iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, - &components, &end_of_sequence); - iterator->Unref(); + MultiDeviceIteratorCallback callback = std::bind( + [ctx](const HostBufferElement& elem, DoneCallback done) { + // iterator->Unref(); + Status s = elem.status; + if (!s.ok()) { + ctx->SetStatus(s); + } else if (elem.end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < elem.value.size(); ++i) { + ctx->set_output(i, elem.value[i]); + } + } + done(); + }, + std::placeholders::_1, std::move(done)); - if (!s.ok()) { - ctx->SetStatus(s); - } else if (end_of_sequence) { - ctx->SetStatus(errors::OutOfRange("End of sequence")); - } else { - for (int i = 0; i < components.size(); ++i) { - // TODO(mrry): Check that the shapes match the shape attrs. - ctx->set_output(i, components[i]); - } - } - done(); + iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + callback); + iterator->Unref(); }, std::move(done))); } diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 66a7c7fdcd..cc5e250ea1 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -168,9 +168,11 @@ output_shapes: The list of shapes being produced. REGISTER_OP("MultiDeviceIteratorInit") .Input("dataset: variant") .Input("multi_device_iterator: resource") + .Input("max_buffer_size: int64") .Output("incarnation_id: int64") .Doc(R"doc( Initializes the multi device iterator with the given dataset. +max_buffer_size: The maximum size of the host side per device buffer to keep. incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator is running. dataset: Dataset to be iterated upon. diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index d66305d732..361fe0dd39 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -1021,7 +1021,7 @@ class MultiDeviceIteratorTest(test.TestCase): def testUneven(self): dataset = dataset_ops.Dataset.range(10) multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/cpu:2"]) + dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4) elem_on_1, elem_on_2 = multi_device_iterator.get_next() config = config_pb2.ConfigProto(device_count={"CPU": 3}) @@ -1079,7 +1079,7 @@ class MultiDeviceIteratorTest(test.TestCase): with compat.forward_compatibility_horizon(2018, 8, 4): dataset = dataset_ops.Dataset.range(10) multi_device_iterator = prefetching_ops.MultiDeviceIterator( - dataset, ["/cpu:1", "/gpu:0"]) + dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4) elem_on_1, elem_on_2 = multi_device_iterator.get_next() config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1}) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index be6fb69fee..5222011d04 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -631,6 +631,7 @@ class MultiDeviceIterator(object): def __init__(self, dataset, devices, + max_buffer_size=1, prefetch_buffer_size=1, source_device="/cpu:0"): """Constructs a MultiDeviceIterator. @@ -638,6 +639,7 @@ class MultiDeviceIterator(object): Args: dataset: The input dataset to be iterated over. devices: The list of devices to fetch data to. + max_buffer_size: Maximum size of the host side per device buffer to keep. prefetch_buffer_size: if > 1, then we setup a buffer on each device to prefetch into. source_device: The host device to place the `dataset` on. @@ -668,7 +670,8 @@ class MultiDeviceIterator(object): # iterators and the multi-device iterator. self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( self._dataset._as_variant_tensor(), # pylint: disable=protected-access - self._multi_device_iterator_resource) + self._multi_device_iterator_resource, + max_buffer_size=max_buffer_size) # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to # initialize the device side of the pipeline. This would allow the |