author    Vijay Vasudevan <vrv@google.com>    2015-11-12 11:27:00 -0800
committer Vijay Vasudevan <vrv@google.com>    2015-11-12 11:27:00 -0800
commit 4dffee7f62d81ec9173aba1b0ef6b96e47f8037c (patch)
tree   4b5c04b37afe45fdc5f1729252514a2770fbf1ab /tensorflow/core/kernels/queue_base.cc
parent f2102f4e2c1c87f1d1bf9ab856a2849c54478760 (diff)
TensorFlow: Minor updates to docs, BUILD, GPU config / perf, etc.
Changes:
- Updates to op documentation and index by Josh
- More changes to BUILD files for python 3 support by @girving
- Fix to Eigen to use DenseIndex everywhere by @jiayq
- Enable configuration for cuda compute capability by @zheng-xq, including updates to docs.
- Route aggregation method through optimizer by schuster
- Updates to install instructions for bazel 0.1.1.

Base CL: 107702099
Diffstat (limited to 'tensorflow/core/kernels/queue_base.cc')
-rw-r--r-- tensorflow/core/kernels/queue_base.cc | 265
1 file changed, 222 insertions(+), 43 deletions(-)
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 4217b9ce86..d0f47505c4 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -46,52 +46,14 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
} // namespace
-// static
-Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
- int index) {
-#define HANDLE_TYPE(DT) \
- if (parent.dtype() == DT) { \
- TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
- return Status::OK(); \
- }
- HANDLE_TYPE(DT_FLOAT);
- HANDLE_TYPE(DT_DOUBLE);
- HANDLE_TYPE(DT_INT32);
- HANDLE_TYPE(DT_UINT8);
- HANDLE_TYPE(DT_INT16);
- HANDLE_TYPE(DT_INT8);
- HANDLE_TYPE(DT_STRING);
- HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", parent.dtype());
-}
-
-// static
-Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
- int index) {
-#define HANDLE_TYPE(DT) \
- if (element.dtype() == DT) { \
- TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
- return Status::OK(); \
- }
- HANDLE_TYPE(DT_FLOAT);
- HANDLE_TYPE(DT_DOUBLE);
- HANDLE_TYPE(DT_INT32);
- HANDLE_TYPE(DT_UINT8);
- HANDLE_TYPE(DT_INT16);
- HANDLE_TYPE(DT_INT8);
- HANDLE_TYPE(DT_STRING);
- HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
- return errors::Unimplemented("Unhandled data type: ", element.dtype());
-}
-
-QueueBase::QueueBase(const DataTypeVector& component_dtypes,
+QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
const std::vector<TensorShape>& component_shapes,
const string& name)
- : component_dtypes_(component_dtypes),
+ : capacity_(capacity),
+ component_dtypes_(component_dtypes),
component_shapes_(component_shapes),
- name_(name) {}
+ name_(name),
+ closed_(false) {}
Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
if (tuple.size() != static_cast<size_t>(num_components())) {
@@ -172,4 +134,221 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
return Status::OK();
}
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ component_shapes_[i].ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ }
+ return Status::OK();
+}
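[ValidateTuple checks each component's shape against the shape declared when the queue was created. A self-contained sketch of that comparison follows; shapes are modeled as std::vector<int64_t>, which is an assumption for illustration (TensorShape is richer), and ValidateShapes is a hypothetical name.

#include <cstdint>
#include <cstdio>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for TensorShape

// Mirrors the IsSameSize loop in ValidateTuple; assumes both vectors
// have equal length (ValidateTupleCommon checks the arity first).
// Returns the index of the first mismatching component, or -1.
int ValidateShapes(const std::vector<Shape>& expected,
                   const std::vector<Shape>& actual) {
  for (size_t i = 0; i < actual.size(); ++i) {
    if (actual[i] != expected[i]) return static_cast<int>(i);
  }
  return -1;
}

int main() {
  std::vector<Shape> expected = {{28, 28}, {}};
  std::vector<Shape> actual   = {{28, 28}, {1}};
  std::printf("first mismatch: %d\n", ValidateShapes(expected, actual));  // 1
}]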
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+// reduce the number of times that they are called.
+Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
+ TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+ const int64 batch_size = tuple[0].dim_size(0);
+ if (specified_shapes()) {
+ for (size_t i = 0; i < tuple.size(); ++i) {
+ // Expected shape is [batch_size] + component_shapes_[i]
+ const TensorShape expected_shape = ManyOutShape(i, batch_size);
+ if (!tuple[i].shape().IsSameSize(expected_shape)) {
+ return errors::InvalidArgument(
+ "Shape mismatch in tuple component ", i, ". Expected ",
+ expected_shape.ShortDebugString(), ", got ",
+ tuple[i].shape().ShortDebugString());
+ }
+ }
+ } else {
+ for (size_t i = 1; i < tuple.size(); ++i) {
+ if (tuple[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All input tensors must have the same size in the 0th ",
+ "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+ ", and should have ", batch_size);
+ }
+ }
+ }
+ return Status::OK();
+}
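[For the batched (*Many) variants, the expected shape of component i is the declared component shape with the batch size prepended, which is what ManyOutShape computes, as the comment in the code notes. A standalone sketch under the same std::vector<int64_t> stand-in:

#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;  // stand-in for TensorShape

// Expected shape of component i in a batched tuple:
// [batch_size] + component_shape.
Shape ManyOutShape(const Shape& component_shape, int64_t batch_size) {
  Shape out;
  out.reserve(component_shape.size() + 1);
  out.push_back(batch_size);
  out.insert(out.end(), component_shape.begin(), component_shape.end());
  return out;
}

For example, a component shape of [28, 28] with batch_size 32 yields [32, 28, 28]. When no shapes were specified, the else-branch above enforces only that every component agrees on dimension 0.]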
+
+void QueueBase::Cancel(Action action, CancellationToken token) {
+ DoneCallback callback = nullptr;
+ {
+ mutex_lock lock(mu_);
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ for (Attempt& attempt : *attempts) {
+ if (attempt.cancellation_token == token) {
+ attempt.is_cancelled = true;
+ if (action == kEnqueue) {
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ } else {
+ attempt.context->SetStatus(
+ errors::Cancelled("Dequeue operation was cancelled"));
+ }
+ std::swap(callback, attempt.done_callback);
+ break;
+ }
+ }
+ }
+ if (callback) {
+ callback();
+ FlushUnlocked();
+ }
+}
+
+void QueueBase::CloseAndCancel() {
+ std::vector<DoneCallback> callbacks;
+ {
+ mutex_lock lock(mu_);
+ closed_ = true;
+ for (Attempt& attempt : enqueue_attempts_) {
+ attempt.is_cancelled = true;
+ attempt.context->SetStatus(
+ errors::Cancelled("Enqueue operation was cancelled"));
+ callbacks.emplace_back(std::move(attempt.done_callback));
+ }
+ }
+ for (const DoneCallback& callback : callbacks) {
+ callback();
+ }
+ FlushUnlocked();
+}
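[Cancel and CloseAndCancel share one locking discipline: mark attempts cancelled and move their done callbacks out while holding mu_, then invoke the callbacks only after the lock is released, because a callback may re-enter the queue (FlushUnlocked retakes mu_). A minimal standalone sketch of that pattern, with hypothetical names:

#include <deque>
#include <functional>
#include <mutex>
#include <utility>
#include <vector>

using DoneCallback = std::function<void()>;

struct PendingAttempt {
  bool is_cancelled = false;
  DoneCallback done_callback;
};

class AttemptList {
 public:
  // Mirrors CloseAndCancel: mutate under the lock, call outside it.
  void CancelAll() {
    std::vector<DoneCallback> callbacks;
    {
      std::lock_guard<std::mutex> lock(mu_);
      closed_ = true;
      for (PendingAttempt& attempt : attempts_) {
        attempt.is_cancelled = true;
        callbacks.emplace_back(std::move(attempt.done_callback));
      }
    }
    for (const DoneCallback& cb : callbacks) {
      if (cb) cb();  // safe: mu_ is no longer held here
    }
  }

 private:
  std::mutex mu_;
  bool closed_ = false;
  std::deque<PendingAttempt> attempts_;
};

Running a callback while still holding mu_ would deadlock as soon as the callback touched the queue again, since the mutex is not reentrant.]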
+
+void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+ DoneCallback callback) {
+ if (cancel_pending_enqueues) {
+ CloseAndCancel();
+ callback();
+ } else {
+ {
+ mutex_lock lock(mu_);
+ enqueue_attempts_.emplace_back(
+ 0, callback, ctx, CancellationManager::kInvalidToken,
+ [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (closed_) {
+ attempt->context->SetStatus(
+ errors::Aborted("Queue '", name_, "' is already closed."));
+ } else {
+ closed_ = true;
+ }
+ return kComplete;
+ });
+ }
+ FlushUnlocked();
+ }
+}
+
+bool QueueBase::TryAttemptLocked(Action action,
+ std::vector<CleanUp>* clean_up) {
+ std::deque<Attempt>* attempts =
+ action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+ bool progress = false;
+ bool done = false;
+ while (!done && !attempts->empty()) {
+ if (attempts->front().is_cancelled) {
+ if (action == kEnqueue) {
+ LOG(INFO) << "Skipping cancelled enqueue attempt";
+ } else {
+ LOG(INFO) << "Skipping cancelled dequeue attempt";
+ }
+ attempts->pop_front();
+ } else {
+ Attempt* cur_attempt = &attempts->front();
+ switch (cur_attempt->run_callback(cur_attempt)) {
+ case kNoProgress:
+ done = true;
+ break;
+ case kProgress:
+ done = true;
+ progress = true;
+ break;
+ case kComplete:
+ progress = true;
+ clean_up->emplace_back(std::move(cur_attempt->done_callback),
+ cur_attempt->cancellation_token,
+ cur_attempt->context->cancellation_manager());
+ attempts->pop_front();
+ break;
+ }
+ }
+ }
+ return progress;
+}
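[TryAttemptLocked is a small state machine over the front of an attempt queue: kNoProgress stops the scan, kProgress stops it while recording that state changed, and kComplete retires the attempt and keeps scanning. Close (above) rides the same machinery by enqueueing an attempt whose callback flips closed_ and returns kComplete. A standalone sketch with hypothetical names (the RunResult values mirror the ones used here):

#include <deque>
#include <functional>

enum RunResult { kNoProgress, kProgress, kComplete };

struct PendingAttempt {
  bool is_cancelled = false;
  std::function<RunResult(PendingAttempt*)> run_callback;
};

// Mirrors the TryAttemptLocked loop: skip cancelled attempts, run the
// front attempt, and stop at the first one that cannot complete.
bool DrainAttempts(std::deque<PendingAttempt>* attempts) {
  bool progress = false;
  bool done = false;
  while (!done && !attempts->empty()) {
    if (attempts->front().is_cancelled) {
      attempts->pop_front();
      continue;
    }
    PendingAttempt* cur = &attempts->front();
    switch (cur->run_callback(cur)) {
      case kNoProgress:
        done = true;
        break;
      case kProgress:
        done = true;
        progress = true;
        break;
      case kComplete:
        progress = true;
        attempts->pop_front();
        break;
    }
  }
  return progress;
}]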
+
+void QueueBase::FlushUnlocked() {
+ std::vector<CleanUp> clean_up;
+ Ref();
+ {
+ mutex_lock lock(mu_);
+ bool changed;
+ do {
+ changed = TryAttemptLocked(kEnqueue, &clean_up);
+ changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+ } while (changed);
+ }
+ Unref();
+ for (const auto& to_clean : clean_up) {
+ if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+ // NOTE(mrry): We can safely ignore the return value of
+ // DeregisterCallback because the mutex mu_ ensures that the
+ // cleanup action only executes once.
+ to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+ }
+ to_clean.finished();
+ }
+}
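[FlushUnlocked runs the two attempt queues to a fixed point: completing an enqueue may unblock a dequeue and vice versa, so it alternates until a full pass makes no progress, and it brackets the loop with Ref()/Unref() so a completing callback cannot drop the last reference to the queue mid-flush. A sketch of the fixed-point loop, reusing the PendingAttempt and DrainAttempts definitions from the sketch above (so it is not independently compilable):

#include <deque>
#include <mutex>

// Alternate enqueue and dequeue drains until neither side changes.
void Flush(std::mutex& mu, std::deque<PendingAttempt>& enqueues,
           std::deque<PendingAttempt>& dequeues) {
  std::lock_guard<std::mutex> lock(mu);
  bool changed;
  do {
    changed = DrainAttempts(&enqueues);
    changed = DrainAttempts(&dequeues) || changed;
  } while (changed);
}

Note the ordering in the original: Unref() happens only after mu_ is released, and each clean-up callback runs after its cancellation callback has been deregistered, so every attempt finishes exactly once.]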
+
+// static
+Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (parent.dtype() == DT) { \
+ TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+}
+
+// static
+Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
+ int index) {
+#define HANDLE_TYPE(DT) \
+ if (element.dtype() == DT) { \
+ TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
+ return Status::OK(); \
+ }
+ HANDLE_TYPE(DT_FLOAT);
+ HANDLE_TYPE(DT_DOUBLE);
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_UINT8);
+ HANDLE_TYPE(DT_INT16);
+ HANDLE_TYPE(DT_INT8);
+ HANDLE_TYPE(DT_STRING);
+ HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+ return errors::Unimplemented("Unhandled data type: ", element.dtype());
+}
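[CopySliceToElement and CopyElementToSlice both dispatch on the runtime dtype by expanding HANDLE_TYPE once per supported type; each expansion compares the dtype and forwards to a templated handler, falling through to an Unimplemented error. A self-contained sketch of the same macro-dispatch pattern over a toy enum; Dispatch and HandleCopy are hypothetical names:

#include <cstdio>

enum DataType { DT_FLOAT, DT_INT32, DT_STRING };

template <DataType DT>
void HandleCopy() {  // stand-in for HandleSliceToElement<DT>
  std::printf("copying as dtype %d\n", static_cast<int>(DT));
}

// Returns true if the dtype was handled; the early return inside the
// macro mirrors the structure of CopySliceToElement above.
bool Dispatch(DataType dtype) {
#define HANDLE_TYPE(DT) \
  if (dtype == DT) {    \
    HandleCopy<DT>();   \
    return true;        \
  }
  HANDLE_TYPE(DT_FLOAT);
  HANDLE_TYPE(DT_INT32);
  HANDLE_TYPE(DT_STRING);
#undef HANDLE_TYPE
  return false;  // unhandled dtype -> errors::Unimplemented in TF
}

int main() { Dispatch(DT_INT32); }]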
+
} // namespace tensorflow