aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/random_shuffle_queue_op.cc
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@gmail.com>2016-04-20 07:53:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-20 09:01:21 -0700
commitb1a6216a82c78bec2ed9c881c51629eb1fa4a7ee (patch)
tree385127b14bfa27f9b22824827d4242403dc64096 /tensorflow/core/kernels/random_shuffle_queue_op.cc
parentd1a31025aaaecbc9137998b587b90221c5a36cb3 (diff)
Add allow_small_batch attribute to QueueInterface, and a new op
called DequeueUpToOp for Queues. In python land, there is a new Queue.dequeue_up_to method. No queues support this dequeue option for now. If a user calls dequeue_up_to, an error is currently returned at runtime. Change: 120341224
Diffstat (limited to 'tensorflow/core/kernels/random_shuffle_queue_op.cc')
-rw-r--r--tensorflow/core/kernels/random_shuffle_queue_op.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index 1c96d6234c..a002328869 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -53,6 +53,7 @@ class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
DoneCallback callback) override;
void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
void TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) override;
Status MatchesNodeDef(const NodeDef& node_def) override;
@@ -256,7 +257,15 @@ void RandomShuffleQueue::TryDequeue(OpKernelContext* ctx,
}
void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
+ bool allow_small_batch,
CallbackWithTuple callback) {
+ if (allow_small_batch) {
+ ctx->SetStatus(
+ errors::Unimplemented("Dequeue: Queue does not support small batches"));
+ callback(Tuple());
+ return;
+ }
+
if (!specified_shapes()) {
ctx->SetStatus(
errors::InvalidArgument("RandomShuffleQueue's DequeueMany requires the "