diff options
author | Vijay Vasudevan <vrv@google.com> | 2015-11-12 11:27:00 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2015-11-12 11:27:00 -0800 |
commit | 4dffee7f62d81ec9173aba1b0ef6b96e47f8037c (patch) | |
tree | 4b5c04b37afe45fdc5f1729252514a2770fbf1ab /tensorflow/core/kernels/queue_base.cc | |
parent | f2102f4e2c1c87f1d1bf9ab856a2849c54478760 (diff) |
TensorFlow: Minor updates to docs, BUILD, GPU config / perf, etc.
Changes:
- Updates to op documentation and index by Josh
- More changes to BUILD files for python 3 support by @girving
- Fix to Eigen to use DenseIndex everywhere by @jiayq
- Enable configuration for cuda compute capability by @zheng-xq,
including updates to docs.
- Route aggregation method through optimizer by schuster
- Updates to install instructions for bazel 0.1.1.
Base CL: 107702099
Diffstat (limited to 'tensorflow/core/kernels/queue_base.cc')
-rw-r--r-- | tensorflow/core/kernels/queue_base.cc | 265 |
1 files changed, 222 insertions, 43 deletions
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index 4217b9ce86..d0f47505c4 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -46,52 +46,14 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) { } // namespace -// static -Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, - int index) { -#define HANDLE_TYPE(DT) \ - if (parent.dtype() == DT) { \ - TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \ - return Status::OK(); \ - } - HANDLE_TYPE(DT_FLOAT); - HANDLE_TYPE(DT_DOUBLE); - HANDLE_TYPE(DT_INT32); - HANDLE_TYPE(DT_UINT8); - HANDLE_TYPE(DT_INT16); - HANDLE_TYPE(DT_INT8); - HANDLE_TYPE(DT_STRING); - HANDLE_TYPE(DT_INT64); -#undef HANDLE_TYPE - return errors::Unimplemented("Unhandled data type: ", parent.dtype()); -} - -// static -Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, - int index) { -#define HANDLE_TYPE(DT) \ - if (element.dtype() == DT) { \ - TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \ - return Status::OK(); \ - } - HANDLE_TYPE(DT_FLOAT); - HANDLE_TYPE(DT_DOUBLE); - HANDLE_TYPE(DT_INT32); - HANDLE_TYPE(DT_UINT8); - HANDLE_TYPE(DT_INT16); - HANDLE_TYPE(DT_INT8); - HANDLE_TYPE(DT_STRING); - HANDLE_TYPE(DT_INT64); -#undef HANDLE_TYPE - return errors::Unimplemented("Unhandled data type: ", element.dtype()); -} - -QueueBase::QueueBase(const DataTypeVector& component_dtypes, +QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes, const std::vector<TensorShape>& component_shapes, const string& name) - : component_dtypes_(component_dtypes), + : capacity_(capacity), + component_dtypes_(component_dtypes), component_shapes_(component_shapes), - name_(name) {} + name_(name), + closed_(false) {} Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { if (tuple.size() != static_cast<size_t>(num_components())) { @@ -172,4 +134,221 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { return Status::OK(); } +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. +Status QueueBase::ValidateTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + if (!tuple[i].shape().IsSameSize(component_shapes_[i])) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + component_shapes_[i].ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } + return Status::OK(); +} + +// TODO(mrry): If these checks become a bottleneck, find a way to +// reduce the number of times that they are called. +Status QueueBase::ValidateManyTuple(const Tuple& tuple) { + TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); + const int64 batch_size = tuple[0].dim_size(0); + if (specified_shapes()) { + for (size_t i = 0; i < tuple.size(); ++i) { + // Expected shape is [batch_size] + component_shapes_[i] + const TensorShape expected_shape = ManyOutShape(i, batch_size); + if (!tuple[i].shape().IsSameSize(expected_shape)) { + return errors::InvalidArgument( + "Shape mismatch in tuple component ", i, ". Expected ", + expected_shape.ShortDebugString(), ", got ", + tuple[i].shape().ShortDebugString()); + } + } + } else { + for (size_t i = 1; i < tuple.size(); ++i) { + if (tuple[i].dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All input tensors must have the same size in the 0th ", + "dimension. Component ", i, " has ", tuple[i].dim_size(0), + ", and should have ", batch_size); + } + } + } + return Status::OK(); +} + +void QueueBase::Cancel(Action action, CancellationToken token) { + DoneCallback callback = nullptr; + { + mutex_lock lock(mu_); + std::deque<Attempt>* attempts = + action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; + + for (Attempt& attempt : *attempts) { + if (attempt.cancellation_token == token) { + attempt.is_cancelled = true; + if (action == kEnqueue) { + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + } else { + attempt.context->SetStatus( + errors::Cancelled("Dequeue operation was cancelled")); + } + std::swap(callback, attempt.done_callback); + break; + } + } + } + if (callback) { + callback(); + FlushUnlocked(); + } +} + +void QueueBase::CloseAndCancel() { + std::vector<DoneCallback> callbacks; + { + mutex_lock lock(mu_); + closed_ = true; + for (Attempt& attempt : enqueue_attempts_) { + attempt.is_cancelled = true; + attempt.context->SetStatus( + errors::Cancelled("Enqueue operation was cancelled")); + callbacks.emplace_back(std::move(attempt.done_callback)); + } + } + for (const DoneCallback& callback : callbacks) { + callback(); + } + FlushUnlocked(); +} + +void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues, + DoneCallback callback) { + if (cancel_pending_enqueues) { + CloseAndCancel(); + callback(); + } else { + { + mutex_lock lock(mu_); + enqueue_attempts_.emplace_back( + 0, callback, ctx, CancellationManager::kInvalidToken, + [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (closed_) { + attempt->context->SetStatus( + errors::Aborted("Queue '", name_, "' is already closed.")); + } else { + closed_ = true; + } + return kComplete; + }); + } + FlushUnlocked(); + } +} + +bool QueueBase::TryAttemptLocked(Action action, + std::vector<CleanUp>* clean_up) { + std::deque<Attempt>* attempts = + action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_; + + bool progress = false; + bool done = false; + while (!done && !attempts->empty()) { + if (attempts->front().is_cancelled) { + if (action == kEnqueue) { + LOG(INFO) << "Skipping cancelled enqueue attempt"; + } else { + LOG(INFO) << "Skipping cancelled dequeue attempt"; + } + attempts->pop_front(); + } else { + Attempt* cur_attempt = &attempts->front(); + switch (cur_attempt->run_callback(cur_attempt)) { + case kNoProgress: + done = true; + break; + case kProgress: + done = true; + progress = true; + break; + case kComplete: + progress = true; + clean_up->emplace_back(std::move(cur_attempt->done_callback), + cur_attempt->cancellation_token, + cur_attempt->context->cancellation_manager()); + attempts->pop_front(); + break; + } + } + } + return progress; +} + +void QueueBase::FlushUnlocked() { + std::vector<CleanUp> clean_up; + Ref(); + { + mutex_lock lock(mu_); + bool changed; + do { + changed = TryAttemptLocked(kEnqueue, &clean_up); + changed = TryAttemptLocked(kDequeue, &clean_up) || changed; + } while (changed); + } + Unref(); + for (const auto& to_clean : clean_up) { + if (to_clean.to_deregister != CancellationManager::kInvalidToken) { + // NOTE(mrry): We can safely ignore the return value of + // DeregisterCallback because the mutex mu_ ensures that the + // cleanup action only executes once. + to_clean.cm->DeregisterCallback(to_clean.to_deregister); + } + to_clean.finished(); + } +} + +// Static method +Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, + int index) { +#define HANDLE_TYPE(DT) \ + if (parent.dtype() == DT) { \ + TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \ + return Status::OK(); \ + } + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_UINT8); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT8); + HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_INT64); +#undef HANDLE_TYPE + return errors::Unimplemented("Unhandled data type: ", parent.dtype()); +} + +// Static method +Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, + int index) { +#define HANDLE_TYPE(DT) \ + if (element.dtype() == DT) { \ + TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \ + return Status::OK(); \ + } + HANDLE_TYPE(DT_FLOAT); + HANDLE_TYPE(DT_DOUBLE); + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_UINT8); + HANDLE_TYPE(DT_INT16); + HANDLE_TYPE(DT_INT8); + HANDLE_TYPE(DT_STRING); + HANDLE_TYPE(DT_INT64); +#undef HANDLE_TYPE + return errors::Unimplemented("Unhandled data type: ", element.dtype()); +} + } // namespace tensorflow |