-rw-r--r--  tensorflow/core/framework/reader_interface.h        12
-rw-r--r--  tensorflow/core/kernels/reader_base.cc               67
-rw-r--r--  tensorflow/core/kernels/reader_base.h                14
-rw-r--r--  tensorflow/core/kernels/reader_ops.cc                52
-rw-r--r--  tensorflow/core/ops/io_ops.cc                        20
-rw-r--r--  tensorflow/python/BUILD                               1
-rw-r--r--  tensorflow/python/kernel_tests/reader_ops_test.py    12
-rw-r--r--  tensorflow/python/ops/io_ops.py                      36
8 files changed, 214 insertions, 0 deletions
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
index f0639b1c4e..dac6056b5a 100644
--- a/tensorflow/core/framework/reader_interface.h
+++ b/tensorflow/core/framework/reader_interface.h
@@ -51,6 +51,18 @@ class ReaderInterface : public ResourceBase {
virtual void Read(QueueInterface* queue, string* key, string* value,
OpKernelContext* context) = 0;
+ // Read up to num_records records into keys / values. May get more work from
+ // *queue if the current work is complete. Sets the status on
+ // *context with an OutOfRange Status if the current work is
+ // complete and the queue is done (closed and empty).
+ // This method may block.
+ // The keys / values vectors are assumed to be empty (and to have most
+ // likely had reserve(num_records) called on them).
+ // Returns how many records were actually read.
+ virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
+ std::vector<string>* keys, std::vector<string>* values,
+ OpKernelContext* context) = 0;
+
// Restore this reader to its newly-constructed state.
virtual Status Reset() = 0;
diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc
index a74e3b2804..0ba6edf328 100644
--- a/tensorflow/core/kernels/reader_base.cc
+++ b/tensorflow/core/kernels/reader_base.cc
@@ -74,6 +74,69 @@ Status ReaderBase::RestoreStateLocked(const string& state) {
return errors::Unimplemented("Reader RestoreState");
}
+int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
+ std::vector<string>* keys,
+ std::vector<string>* values,
+ OpKernelContext* context) {
+ mutex_lock lock(mu_);
+ int64 records_produced_this_call = 0;
+ while (true) {
+ // Records produced by this iteration's call to ReadUpToLocked.
+ int64 num_records_produced = 0;
+ int64 remaining = num_records - records_produced_this_call;
+ if (remaining == 0) {
+ return records_produced_this_call;
+ }
+ if (!work_in_progress()) {
+ GetNextWorkLocked(queue, context);
+ if (!context->status().ok()) return records_produced_this_call;
+ }
+ bool at_end = false;
+
+ Status status =
+ ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
+ // This call so far.
+ records_produced_this_call += num_records_produced;
+
+ // In total, over the lifetime of the ReaderBase.
+ num_records_produced_ += num_records_produced;
+
+ if (!at_end && status.ok() && num_records_produced == 0) {
+ status = errors::Internal(
+ "ReadUpToLocked() for ", name(),
+ " must set *at_end=true, *num_read > 0, or return an error.");
+ context->SetStatus(status);
+ return records_produced_this_call;
+ }
+ if (status.ok() && at_end) {
+ status = OnWorkFinishedLocked();
+ work_finished_ = work_started_;
+ }
+ if (!status.ok()) {
+ context->SetStatus(status);
+ return records_produced_this_call;
+ }
+ }
+}
+
+// Default implementation just reads one record at a time.
+Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
+ std::vector<string>* values, int64* num_read,
+ bool* at_end) {
+ bool produced;
+ string key;
+ string value;
+ Status status = ReadLocked(&key, &value, &produced, at_end);
+ if (produced) {
+ keys->emplace_back(key);
+ values->emplace_back(value);
+ *num_read = 1;
+ } else {
+ *num_read = 0;
+ }
+ return status;
+}
+
void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
OpKernelContext* context) {
mutex_lock lock(mu_);
@@ -150,6 +213,10 @@ void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
state->set_current_work(work_);
}
+string ReaderBase::KeyName(const string& key) const {
+ return strings::StrCat(current_work(), ":", key);
+}
+
Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
work_started_ = state.work_started();
work_finished_ = state.work_finished();
diff --git a/tensorflow/core/kernels/reader_base.h b/tensorflow/core/kernels/reader_base.h
index 654aa8dca8..90bb61fa6c 100644
--- a/tensorflow/core/kernels/reader_base.h
+++ b/tensorflow/core/kernels/reader_base.h
@@ -56,6 +56,12 @@ class ReaderBase : public ReaderInterface {
// Descendants may optionally implement these -------------------------------
+ // Produce up to num_records next key/value pairs from the current
+ // work item, in the same manner as ReadLocked.
+ virtual Status ReadUpToLocked(int64 num_records, std::vector<string>* keys,
+ std::vector<string>* values, int64* num_read,
+ bool* at_end);
+
// Called when work starts / finishes.
virtual Status OnWorkStartedLocked() { return Status::OK(); }
virtual Status OnWorkFinishedLocked() { return Status::OK(); }
@@ -81,6 +87,9 @@ class ReaderBase : public ReaderInterface {
// What was passed to the constructor.
const string& name() const { return name_; }
+ // Produce the key name (from current_work and the actual key).
+ string KeyName(const string& key) const;
+
protected:
// For descendants wishing to implement serialize & restore state.
@@ -96,6 +105,11 @@ class ReaderBase : public ReaderInterface {
// and call the methods above to do the work.
void Read(QueueInterface* queue, string* key, string* value,
OpKernelContext* context) override;
+
+ int64 ReadUpTo(const int64 num_records, QueueInterface* queue,
+ std::vector<string>* keys, std::vector<string>* values,
+ OpKernelContext* context) override;
+
Status Reset() override;
int64 NumRecordsProduced() override;
int64 NumWorkUnitsCompleted() override;
diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc
index 3960eea677..1c7fbae81c 100644
--- a/tensorflow/core/kernels/reader_ops.cc
+++ b/tensorflow/core/kernels/reader_ops.cc
@@ -97,6 +97,58 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
+class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
+ public:
+ using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
+
+ void ComputeWithReader(OpKernelContext* context,
+ ReaderInterface* reader) override {
+ QueueInterface* queue;
+
+ const Tensor* num_records_tensor;
+ OP_REQUIRES_OK(context, context->input("num_records", &num_records_tensor));
+ int64 num_records = num_records_tensor->scalar<int64>()();
+
+ OP_REQUIRES_OK(context,
+ GetResourceFromContext(context, "queue_handle", &queue));
+ core::ScopedUnref unref_me(queue);
+
+ std::vector<string> keys_vec;
+ keys_vec.reserve(num_records);
+ std::vector<string> values_vec;
+ values_vec.reserve(num_records);
+
+ int64 num_actually_read =
+ reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context);
+
+ OP_REQUIRES(context, num_actually_read == keys_vec.size(),
+ errors::InvalidArgument("num_actually_read != len(keys_vec)"));
+
+ OP_REQUIRES(context, num_actually_read == values_vec.size(),
+ errors::InvalidArgument("num_actually_read != len(values_vec)"));
+
+ Tensor* keys = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ "keys", TensorShape({num_actually_read}), &keys));
+
+ Tensor* values = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ "values", TensorShape({num_actually_read}), &values));
+
+ auto keys_t = keys->vec<string>();
+ auto values_t = values->vec<string>();
+ for (int i = 0; i < num_actually_read; ++i) {
+ keys_t(i) = std::move(keys_vec[i]);
+ values_t(i) = std::move(values_vec[i]);
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU),
+ ReaderReadUpToOp);
+
class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
public:
using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 09ffee0604..9b0fe4f2d5 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -266,6 +266,26 @@ key: A scalar.
value: A scalar.
)doc");
+REGISTER_OP("ReaderReadUpTo")
+ .Input("reader_handle: Ref(string)")
+ .Input("queue_handle: Ref(string)")
+ .Input("num_records: int64")
+ .Output("keys: string")
+ .Output("values: string")
+ .Doc(R"doc(
+Returns up to `num_records` (key, value) pairs produced by a Reader.
+
+Will dequeue from the input queue if necessary (e.g. when the
+Reader needs to start reading from a new file since it has finished
+with the previous file).
+
+reader_handle: Handle to a `Reader`.
+queue_handle: Handle to a `Queue`, with string work items.
+num_records: Number of records to read from `Reader`.
+keys: A 1-D tensor.
+values: A 1-D tensor.
+)doc");
+
REGISTER_OP("ReaderNumRecordsProduced")
.Input("reader_handle: Ref(string)")
.Output("records_produced: int64")
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 91248812f4..2e9d5ec2c9 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -599,6 +599,7 @@ tf_gen_op_wrapper_py(
"ReaderNumRecordsProduced",
"ReaderNumWorkUnitsCompleted",
"ReaderRead",
+ "ReaderReadUpTo",
"ReaderReset",
"ReaderRestoreState",
"ReaderSerializeState",
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 1b55f02285..c49ff3fb49 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -32,6 +32,18 @@ class IdentityReaderTest(tf.test.TestCase):
self.assertAllEqual(expected, k)
self.assertAllEqual(expected, v)
+ def testReadUpTo(self):
+ # Note that this just exercises the default ReadUpToLocked
+ # implementation, since it is not overridden for IdentityReader.
+ with self.test_session() as sess:
+ reader = tf.IdentityReader("test_reader")
+ queue = tf.FIFOQueue(99, [tf.string], shapes=())
+ keys, values = reader.read_up_to(queue, 3)
+ queue.enqueue_many([["A", "B", "C"]]).run()
+ k, v = sess.run([keys, values])
+ self.assertAllEqual([b"A", b"B", b"C"], k)
+ self.assertAllEqual([b"A", b"B", b"C"], v)
+
def testOneEpoch(self):
with self.test_session() as sess:
reader = tf.IdentityReader("test_reader")
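For reference, a short hedged sketch (not part of this change, assuming the same imports as the test above; the reader and queue names are illustrative): because the reader is stateful, running the same read_up_to ops repeatedly keeps consuming work items from the queue.

    with tf.Session() as sess:
      reader = tf.IdentityReader("chunk_demo")  # hypothetical name
      queue = tf.FIFOQueue(99, [tf.string], shapes=())
      queue.enqueue_many([["A", "B", "C", "D"]]).run()
      keys, values = reader.read_up_to(queue, 2)
      # Each run of the op reads up to 2 records, dequeuing work items as needed.
      print(sess.run([keys, values]))  # [b"A", b"B"], [b"A", b"B"]
      print(sess.run([keys, values]))  # [b"C", b"D"], [b"C", b"D"]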
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 97fe48bab7..757aeb5a53 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -332,6 +332,34 @@ class ReaderBase(object):
queue_ref = queue.queue_ref
return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name)
+ def read_up_to(self, queue, num_records, # pylint: disable=invalid-name
+ name=None):
+ """Returns up to num_records (key, value) pairs produced by a reader.
+
+ Will dequeue a work unit from queue if necessary (e.g., when the
+ Reader needs to start reading from a new file since it has
+ finished with the previous file).
+
+ Args:
+ queue: A Queue or a mutable string Tensor representing a handle
+ to a Queue, with string work items.
+ num_records: Number of records to read.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of Tensors (keys, values).
+ keys: A 1-D string Tensor.
+ values: A 1-D string Tensor.
+ """
+ if isinstance(queue, ops.Tensor):
+ queue_ref = queue
+ else:
+ queue_ref = queue.queue_ref
+ return gen_io_ops._reader_read_up_to(self._reader_ref,
+ queue_ref,
+ num_records,
+ name=name)
+
def num_records_produced(self, name=None):
"""Returns the number of records this reader has produced.
@@ -407,6 +435,7 @@ class ReaderBase(object):
ops.NoGradient("ReaderRead")
+ops.NoGradient("ReaderReadUpTo")
ops.NoGradient("ReaderNumRecordsProduced")
ops.NoGradient("ReaderNumWorkUnitsCompleted")
ops.NoGradient("ReaderSerializeState")
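Putting the pieces above together, a hedged end-to-end usage sketch (not part of this change; the filenames are hypothetical and the queue-runner setup is the standard one for this API, assumed here rather than taken from the diff):

    import tensorflow as tf

    # Hypothetical input files; string_input_producer cycles them into a queue.
    filename_queue = tf.train.string_input_producer(
        ["train-0.tfrecords", "train-1.tfrecords"])
    reader = tf.TFRecordReader()
    # Ask for up to 1000 serialized records per session.run call.
    keys, values = reader.read_up_to(filename_queue, 1000)

    with tf.Session() as sess:
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      k, v = sess.run([keys, values])  # k and v are 1-D arrays of equal length
      coord.request_stop()
      coord.join(threads)

Compared with read(), which produces a single scalar (key, value) pair per run, read_up_to returns two 1-D string tensors holding up to num_records entries, so per-run overhead is amortized across many records.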
@@ -555,6 +584,13 @@ def _ReaderReadShape(op):
return [tensor_shape.scalar(), tensor_shape.scalar()]
+@ops.RegisterShape("ReaderReadUpTo")
+def _ReaderReadUpToShape(_):
+ """Shape function for the ReaderBase.ReadUpTo op."""
+ return [tensor_shape.unknown_shape(ndims=1),
+ tensor_shape.unknown_shape(ndims=1)]
+
+
@ops.RegisterShape("ReaderReset")
def _ReaderResetShape(op):
"""Shape function for the ReaderBase.Reset op."""