author    Rohan Jain <rohanj@google.com>  2018-08-10 17:05:25 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-10 17:09:12 -0700
commit    12eb80cb9b4b51631a7cdfc9fce476a8b2ea225b (patch)
tree      b438ba58a924721a8e84a159bbbee647ddfa9efa /tensorflow/contrib/data
parent    6c08c6c22a7ccd3adad28fb76269122ab0a1fcaa (diff)
Speeding up MultiDeviceIterator with more efficient locking. We create a background thread that keeps a host-side buffer full for each device. When a GetNext request comes in, we return an element from the buffer if one is available; otherwise we schedule a callback to be invoked once the background thread fetches an element for that device.
PiperOrigin-RevId: 208292329
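
The scheme is easier to see in isolation than inside the kernel plumbing. Below is a minimal, self-contained C++ sketch of the pattern the new MultiDeviceBuffer class implements: one producer thread round-robining over per-shard bounded buffers, with consumers either popping a ready element or parking a callback for the producer to fire later. All names here (ShardBuffers, Element, the toy ten-element dataset) are illustrative stand-ins, not TensorFlow APIs, and shutdown is simplified; the real kernel also drains parked callbacks on cancellation via RunPendingCallbacks.

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

struct Element {
  bool end_of_sequence = false;
  int value = 0;
};
using Callback = std::function<void(const Element&)>;

class ShardBuffers {
 public:
  ShardBuffers(size_t num_shards, size_t max_buffer_size)
      : shards_(num_shards),
        max_buffer_size_(max_buffer_size),
        producer_([this] { ProducerLoop(); }) {}

  ~ShardBuffers() {
    {
      std::lock_guard<std::mutex> l(mu_);
      cancelled_ = true;
      // Wake the producer wherever it is blocked.
      for (auto& s : shards_) s.cond_var.notify_all();
    }
    producer_.join();
    // Simplification: parked callbacks are dropped here; the real kernel
    // fails them with a Cancelled status instead.
  }

  // Serve from the shard's buffer if an element is ready; otherwise park
  // the callback until the producer fills this shard.
  void GetNext(int shard, Callback cb) {
    Element elem;
    bool ready = false;
    {
      std::lock_guard<std::mutex> l(mu_);
      Shard& s = shards_[shard];
      if (!s.data.empty()) {
        elem = s.data.front();
        s.data.pop_front();
        // The producer may be blocked on this buffer being full.
        s.cond_var.notify_all();
        ready = true;
      } else {
        s.callbacks.push_back(std::move(cb));
      }
    }
    if (ready) cb(elem);  // Invoke outside the lock, as the kernel does.
  }

 private:
  struct Shard {
    std::condition_variable cond_var;
    std::deque<Element> data;
    std::deque<Callback> callbacks;
  };

  void ProducerLoop() {
    int next = 0;     // Round-robin shard cursor.
    int counter = 0;  // Stand-in for the host iterator's GetNext().
    while (true) {
      Element elem;
      Callback cb;
      {
        std::unique_lock<std::mutex> l(mu_);
        Shard& s = shards_[next];
        // Block while this shard's buffer is full, like the background
        // thread waiting on buffer_[shard_to_fetch].cond_var.
        s.cond_var.wait(l, [&] {
          return cancelled_ || s.data.size() < max_buffer_size_;
        });
        if (cancelled_) return;
        elem.value = counter++;
        elem.end_of_sequence = counter > 10;  // Toy ten-element dataset.
        if (!s.callbacks.empty()) {  // A consumer is already waiting.
          cb = std::move(s.callbacks.front());
          s.callbacks.pop_front();
        } else {
          s.data.push_back(elem);
        }
      }
      if (cb) cb(elem);
      if (elem.end_of_sequence) return;
      next = (next + 1) % static_cast<int>(shards_.size());
    }
  }

  std::mutex mu_;
  bool cancelled_ = false;
  std::vector<Shard> shards_;
  const size_t max_buffer_size_;
  std::thread producer_;
};

int main() {
  ShardBuffers buffers(/*num_shards=*/2, /*max_buffer_size=*/4);
  for (int i = 0; i < 6; ++i) {
    buffers.GetNext(i % 2, [](const Element& e) {
      std::cout << "value=" << e.value
                << " end_of_sequence=" << e.end_of_sequence << "\n";
    });
  }
}

Note the design choice mirrored from the commit: callbacks are always invoked outside the lock, so a slow consumer callback cannot stall the producer or block GetNext calls on other shards.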
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/kernels/prefetching_kernels.cc              | 356
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc                          |   2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py |   4
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py               |   5
4 files changed, 271 insertions, 96 deletions
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 32f03ca683..13bcd77b4a 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -526,6 +526,15 @@ string SanitizeThreadSuffix(string suffix) {
return clean;
}
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
class MultiDeviceIterator : public ResourceBase {
public:
MultiDeviceIterator(const DataTypeVector& output_types,
@@ -539,83 +548,45 @@ class MultiDeviceIterator : public ResourceBase {
devices_(devices),
flib_def_(std::move(flib_def)),
pflr_(std::move(pflr)),
- lib_(lib) {
- buffer_.resize(devices_.size());
- }
+ lib_(lib) {}
string DebugString() override {
- return strings::StrCat("MultiDeviceIterator");
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
}
- Status Init(std::unique_ptr<IteratorBase> iterator, int64* incarnation_id) {
- mutex_lock l(mu_);
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
if (iterator) {
TF_RETURN_IF_ERROR(
VerifyTypesMatch(output_types_, iterator->output_dtypes()));
TF_RETURN_IF_ERROR(
VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
}
- host_iterator_.reset(iterator.release());
- incarnation_id_++;
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
*incarnation_id = incarnation_id_;
- max_buffer_size_ = 0;
- num_elements_ = 0;
- buffer_.clear();
- buffer_.resize(devices_.size());
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
return Status::OK();
}
- Status GetNextFromShard(IteratorContext* ctx, int shard_num,
- int64 incarnation_id,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) {
- // TODO(rohanj): This might potentially strand elements in other shards.
- // Opportunity to do smarter locking semantics.
- mutex_lock l(mu_);
- // Make sure we're in the right incarnation.
- if (incarnation_id != incarnation_id_) {
- return errors::InvalidArgument(
- "Current incarnation: ", incarnation_id_,
- "; Supplied incarnation: ", incarnation_id);
- }
- // Then look it up in the buffer.
- if (!buffer_[shard_num].empty()) {
- const HostBufferElement& elem = buffer_[shard_num].front();
- *out_tensors = elem.value;
- *end_of_sequence = elem.end_of_sequence;
- Status s = elem.status;
- buffer_[shard_num].pop_front();
- return s;
- }
- std::shared_ptr<IteratorBase> captured_iterator(host_iterator_);
- if (captured_iterator) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
- while (true) {
- HostBufferElement elem;
- elem.status =
- captured_iterator->GetNext(ctx, &elem.value, &elem.end_of_sequence);
- int buffer_index = num_elements_ % devices_.size();
- num_elements_++;
- if (buffer_index == shard_num) {
- out_tensors->swap(elem.value);
- *end_of_sequence = elem.end_of_sequence;
- return elem.status;
- } else {
- buffer_[buffer_index].push_back(std::move(elem));
- // TODO(rohanj): Put an upper bound to buffer size.
- if (buffer_[buffer_index].size() > max_buffer_size_) {
- max_buffer_size_ = buffer_[buffer_index].size();
- VLOG(1) << "MultiDeviceIterator: Max buffer size increased to: "
- << max_buffer_size_;
- }
- }
- }
- } else {
- return errors::FailedPrecondition("Iterator not initialized");
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
}
- return Status::OK();
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
}
const DataTypeVector& output_types() const { return output_types_; }
@@ -630,25 +601,218 @@ class MultiDeviceIterator : public ResourceBase {
}
private:
- struct HostBufferElement {
- Status status;
- bool end_of_sequence;
- std::vector<Tensor> value;
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() { Reset(); }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (background_thread_finished_) {
+ return;
+ }
+
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (background_thread_finished_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
};
mutex mu_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
const std::vector<string> devices_;
- int64 num_elements_ GUARDED_BY(mu_) = 0;
- int64 max_buffer_size_ GUARDED_BY(mu_) = 0;
- int64 incarnation_id_ GUARDED_BY(mu_) = 0;
- std::vector<std::deque<HostBufferElement>> buffer_ GUARDED_BY(mu_);
- std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionLibraryRuntime* lib_ = nullptr; // not owned.
- std::shared_ptr<IteratorBase> host_iterator_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
};
// Just creates a MultiDeviceIterator and returns it.
@@ -754,6 +918,10 @@ class MultiDeviceIteratorInitOp : public OpKernel {
: OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
MultiDeviceIterator* resource;
@@ -766,7 +934,8 @@ class MultiDeviceIteratorInitOp : public OpKernel {
OP_REQUIRES_OK(ctx,
dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
int64 incarnation_id;
- OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), &incarnation_id));
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
tensor_incarnation_id.scalar<int64>()() = incarnation_id;
OP_REQUIRES_OK(ctx,
@@ -804,9 +973,6 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
thread_pool_->Schedule(std::bind(
[ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
- std::vector<Tensor> components;
- bool end_of_sequence = false;
-
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
@@ -817,22 +983,26 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
};
IteratorContext iter_ctx(std::move(params));
- Status s =
- iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
- &components, &end_of_sequence);
- iterator->Unref();
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
- if (!s.ok()) {
- ctx->SetStatus(s);
- } else if (end_of_sequence) {
- ctx->SetStatus(errors::OutOfRange("End of sequence"));
- } else {
- for (int i = 0; i < components.size(); ++i) {
- // TODO(mrry): Check that the shapes match the shape attrs.
- ctx->set_output(i, components[i]);
- }
- }
- done();
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
},
std::move(done)));
}
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index 66a7c7fdcd..cc5e250ea1 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -168,9 +168,11 @@ output_shapes: The list of shapes being produced.
REGISTER_OP("MultiDeviceIteratorInit")
.Input("dataset: variant")
.Input("multi_device_iterator: resource")
+ .Input("max_buffer_size: int64")
.Output("incarnation_id: int64")
.Doc(R"doc(
Initializes the multi device iterator with the given dataset.
+max_buffer_size: The maximum size of the host-side per-device buffer to keep.
incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator
is running.
dataset: Dataset to be iterated upon.
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index d66305d732..361fe0dd39 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -1021,7 +1021,7 @@ class MultiDeviceIteratorTest(test.TestCase):
def testUneven(self):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
+ dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
config = config_pb2.ConfigProto(device_count={"CPU": 3})
@@ -1079,7 +1079,7 @@ class MultiDeviceIteratorTest(test.TestCase):
with compat.forward_compatibility_horizon(2018, 8, 4):
dataset = dataset_ops.Dataset.range(10)
multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"])
+ dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index be6fb69fee..5222011d04 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -631,6 +631,7 @@ class MultiDeviceIterator(object):
def __init__(self,
dataset,
devices,
+ max_buffer_size=1,
prefetch_buffer_size=1,
source_device="/cpu:0"):
"""Constructs a MultiDeviceIterator.
@@ -638,6 +639,7 @@ class MultiDeviceIterator(object):
Args:
dataset: The input dataset to be iterated over.
devices: The list of devices to fetch data to.
+ max_buffer_size: Maximum size of the host-side per-device buffer to keep.
prefetch_buffer_size: if > 1, then we setup a buffer on each device
to prefetch into.
source_device: The host device to place the `dataset` on.
@@ -668,7 +670,8 @@ class MultiDeviceIterator(object):
# iterators and the multi-device iterator.
self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
self._dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._multi_device_iterator_resource)
+ self._multi_device_iterator_resource,
+ max_buffer_size=max_buffer_size)
# TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
# initialize the device side of the pipeline. This would allow the