diff options
-rw-r--r-- | tensorflow/core/framework/reader_interface.h | 12 | ||||
-rw-r--r-- | tensorflow/core/kernels/reader_base.cc | 67 | ||||
-rw-r--r-- | tensorflow/core/kernels/reader_base.h | 14 | ||||
-rw-r--r-- | tensorflow/core/kernels/reader_ops.cc | 52 | ||||
-rw-r--r-- | tensorflow/core/ops/io_ops.cc | 20 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/reader_ops_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/io_ops.py | 36 |
8 files changed, 214 insertions, 0 deletions
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h index f0639b1c4e..dac6056b5a 100644 --- a/tensorflow/core/framework/reader_interface.h +++ b/tensorflow/core/framework/reader_interface.h @@ -51,6 +51,18 @@ class ReaderInterface : public ResourceBase { virtual void Read(QueueInterface* queue, string* key, string* value, OpKernelContext* context) = 0; + // Read up to num_records records into keys / values. May get more work from + // *queue if the current work is complete. Sets the status on + // *context with an OutOfRange Status if the current work is + // complete and the queue is done (closed and empty). + // This method may block. + // The std::vector keys/value pointers are assumed to point to empty + // structures (that have most likely been reserve(num_records)). + // Returns how many records were actually read. + virtual int64 ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector<string>* keys, std::vector<string>* value, + OpKernelContext* context) = 0; + // Restore this reader to its newly-constructed state. virtual Status Reset() = 0; diff --git a/tensorflow/core/kernels/reader_base.cc b/tensorflow/core/kernels/reader_base.cc index a74e3b2804..0ba6edf328 100644 --- a/tensorflow/core/kernels/reader_base.cc +++ b/tensorflow/core/kernels/reader_base.cc @@ -74,6 +74,69 @@ Status ReaderBase::RestoreStateLocked(const string& state) { return errors::Unimplemented("Reader RestoreState"); } +int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector<string>* keys, + std::vector<string>* values, + OpKernelContext* context) { + mutex_lock lock(mu_); + int64 records_produced_this_call = 0; + while (true) { + // Records produced by this iteration of the ReadUpToLocked call. + int64 num_records_produced = 0; + int64 remaining = num_records - records_produced_this_call; + if (remaining == 0) { + return records_produced_this_call; + } + if (!work_in_progress()) { + GetNextWorkLocked(queue, context); + if (!context->status().ok()) return records_produced_this_call; + } + bool at_end = false; + + Status status = + ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end); + // This call so far. + records_produced_this_call += num_records_produced; + + // In total, over the lifetime of the ReaderBase. + num_records_produced_ += num_records_produced; + + if (!at_end && status.ok() && num_records_produced == 0) { + status = errors::Internal( + "ReadManyLocked() for ", name(), + " must set *at_end=true, *num_produced > 0 or return an error."); + context->SetStatus(status); + return records_produced_this_call; + } + if (status.ok() && at_end) { + status = OnWorkFinishedLocked(); + work_finished_ = work_started_; + } + if (!status.ok()) { + context->SetStatus(status); + return records_produced_this_call; + } + } +} + +// Default implementation just reads one record at a time. +Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys, + std::vector<string>* values, int64* num_read, + bool* at_end) { + bool produced; + string key; + string value; + Status status = ReadLocked(&key, &value, &produced, at_end); + if (produced) { + keys->emplace_back(key); + values->emplace_back(value); + *num_read = 1; + } else { + *num_read = 0; + } + return status; +} + void ReaderBase::Read(QueueInterface* queue, string* key, string* value, OpKernelContext* context) { mutex_lock lock(mu_); @@ -150,6 +213,10 @@ void ReaderBase::SaveBaseState(ReaderBaseState* state) const { state->set_current_work(work_); } +string ReaderBase::KeyName(const string& key) const { + return strings::StrCat(current_work(), ":", key); +} + Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { work_started_ = state.work_started(); work_finished_ = state.work_finished(); diff --git a/tensorflow/core/kernels/reader_base.h b/tensorflow/core/kernels/reader_base.h index 654aa8dca8..90bb61fa6c 100644 --- a/tensorflow/core/kernels/reader_base.h +++ b/tensorflow/core/kernels/reader_base.h @@ -56,6 +56,12 @@ class ReaderBase : public ReaderInterface { // Descendants may optionally implement these ------------------------------- + // Produce up to num_records next key/value pairs from the current + // work item, in the same manner of ReadLocked. + virtual Status ReadUpToLocked(int64 num_records, std::vector<string>* keys, + std::vector<string>* values, int64* num_read, + bool* at_end); + // Called when work starts / finishes. virtual Status OnWorkStartedLocked() { return Status::OK(); } virtual Status OnWorkFinishedLocked() { return Status::OK(); } @@ -81,6 +87,9 @@ class ReaderBase : public ReaderInterface { // What was passed to the constructor. const string& name() const { return name_; } + // Produce the key name (from current_work and the actual key). + string KeyName(const string& key) const; + protected: // For descendants wishing to implement serialize & restore state. @@ -96,6 +105,11 @@ class ReaderBase : public ReaderInterface { // and call the methods above to do the work. void Read(QueueInterface* queue, string* key, string* value, OpKernelContext* context) override; + + int64 ReadUpTo(const int64 num_records, QueueInterface* queue, + std::vector<string>* keys, std::vector<string>* value, + OpKernelContext* context); + Status Reset() override; int64 NumRecordsProduced() override; int64 NumWorkUnitsCompleted() override; diff --git a/tensorflow/core/kernels/reader_ops.cc b/tensorflow/core/kernels/reader_ops.cc index 3960eea677..1c7fbae81c 100644 --- a/tensorflow/core/kernels/reader_ops.cc +++ b/tensorflow/core/kernels/reader_ops.cc @@ -97,6 +97,58 @@ class ReaderReadOp : public ReaderVerbAsyncOpKernel { REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp); +class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { + public: + using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel; + + void ComputeWithReader(OpKernelContext* context, + ReaderInterface* reader) override { + QueueInterface* queue; + + const Tensor* num_records_tensor; + OP_REQUIRES_OK(context, context->input("num_records", &num_records_tensor)); + int64 num_records = num_records_tensor->scalar<int64>()(); + + OP_REQUIRES_OK(context, + GetResourceFromContext(context, "queue_handle", &queue)); + core::ScopedUnref unref_me(queue); + + std::vector<string> keys_vec; + keys_vec.reserve(num_records); + std::vector<string> values_vec; + values_vec.reserve(num_records); + + int64 num_actually_read = + reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context); + + OP_REQUIRES(context, num_actually_read == keys_vec.size(), + errors::InvalidArgument("num_actually_read != len(keys_vec")); + + OP_REQUIRES(context, num_actually_read == values_vec.size(), + errors::InvalidArgument("num_actually_read != len(values_vec")); + + Tensor* keys = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + "keys", TensorShape({num_actually_read}), &keys)); + + Tensor* values = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + "values", TensorShape({num_actually_read}), &values)); + + auto keys_t = keys->vec<string>(); + auto values_t = values->vec<string>(); + for (int i = 0; i < num_actually_read; ++i) { + keys_t(i) = std::move(keys_vec[i]); + values_t(i) = std::move(values_vec[i]); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU), + ReaderReadUpToOp); + class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel { public: using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index 09ffee0604..9b0fe4f2d5 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -266,6 +266,26 @@ key: A scalar. value: A scalar. )doc"); +REGISTER_OP("ReaderReadUpTo") + .Input("reader_handle: Ref(string)") + .Input("queue_handle: Ref(string)") + .Input("num_records: int64") + .Output("keys: string") + .Output("values: string") + .Doc(R"doc( +Returns up to `num_records` (key, value) pairs produced by a Reader. + +Will dequeue from the input queue if necessary (e.g. when the +Reader needs to start reading from a new file since it has finished +with the previous file). + +reader_handle: Handle to a `Reader`. +queue_handle: Handle to a `Queue`, with string work items. +num_records: number of records to read from `Reader`. +keys: A 1-D tensor. +values: A 1-D tensor. +)doc"); + REGISTER_OP("ReaderNumRecordsProduced") .Input("reader_handle: Ref(string)") .Output("records_produced: int64") diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 91248812f4..2e9d5ec2c9 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -599,6 +599,7 @@ tf_gen_op_wrapper_py( "ReaderNumRecordsProduced", "ReaderNumWorkUnitsCompleted", "ReaderRead", + "ReaderReadUpTo", "ReaderReset", "ReaderRestoreState", "ReaderSerializeState", diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 1b55f02285..c49ff3fb49 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -32,6 +32,18 @@ class IdentityReaderTest(tf.test.TestCase): self.assertAllEqual(expected, k) self.assertAllEqual(expected, v) + def testReadUpTo(self): + # Note that this just tests the default ReaderReadUpTo + # since it is not overriden for IdentityReader. + with self.test_session() as sess: + reader = tf.IdentityReader("test_reader") + queue = tf.FIFOQueue(99, [tf.string], shapes=()) + keys, values = reader.read_up_to(queue, 3) + queue.enqueue_many([["A", "B", "C"]]).run() + k, v = sess.run([keys, values]) + self.assertAllEqual([b"A", b"B", b"C"], k) + self.assertAllEqual([b"A", b"B", b"C"], v) + def testOneEpoch(self): with self.test_session() as sess: reader = tf.IdentityReader("test_reader") diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 97fe48bab7..757aeb5a53 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -332,6 +332,34 @@ class ReaderBase(object): queue_ref = queue.queue_ref return gen_io_ops._reader_read(self._reader_ref, queue_ref, name=name) + def read_up_to(self, queue, num_records, # pylint: disable=invalid-name + name=None): + """Returns up to num_records (key, value pairs) produced by a reader. + + Will dequeue a work unit from queue if necessary (e.g., when the + Reader needs to start reading from a new file since it has + finished with the previous file). + + Args: + queue: A Queue or a mutable string Tensor representing a handle + to a Queue, with string work items. + num_records: Number of records to read. + name: A name for the operation (optional). + + Returns: + A tuple of Tensors (keys, values). + keys: A 1-D string Tensor. + values: A 1-D string Tensor. + """ + if isinstance(queue, ops.Tensor): + queue_ref = queue + else: + queue_ref = queue.queue_ref + return gen_io_ops._reader_read_up_to(self._reader_ref, + queue_ref, + num_records, + name=name) + def num_records_produced(self, name=None): """Returns the number of records this reader has produced. @@ -407,6 +435,7 @@ class ReaderBase(object): ops.NoGradient("ReaderRead") +ops.NoGradient("ReaderReadUpTo") ops.NoGradient("ReaderNumRecordsProduced") ops.NoGradient("ReaderNumWorkUnitsCompleted") ops.NoGradient("ReaderSerializeState") @@ -555,6 +584,13 @@ def _ReaderReadShape(op): return [tensor_shape.scalar(), tensor_shape.scalar()] +@ops.RegisterShape("ReaderReadUpTo") +def _ReaderReadUpToShape(_): + """Shape function for the ReaderBase.ReadUpTo op.""" + return [tensor_shape.unknown_shape(ndims=1), + tensor_shape.unknown_shape(ndims=1)] + + @ops.RegisterShape("ReaderReset") def _ReaderResetShape(op): """Shape function for the ReaderBase.Reset op.""" |