diff options
Diffstat (limited to 'tensorflow/core/kernels/tf_record_reader_op.cc')
-rw-r--r-- | tensorflow/core/kernels/tf_record_reader_op.cc | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc new file mode 100644 index 0000000000..551be18d5f --- /dev/null +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -0,0 +1,76 @@ +// See docs in ../ops/io_ops.cc. + +#include <memory> +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/kernels/reader_base.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/record_reader.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/env.h" + +namespace tensorflow { + +class TFRecordReader : public ReaderBase { + public: + TFRecordReader(const string& node_name, Env* env) + : ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")), + env_(env), + offset_(0) {} + + Status OnWorkStartedLocked() override { + offset_ = 0; + RandomAccessFile* file = nullptr; + TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); + file_.reset(file); + reader_.reset(new io::RecordReader(file)); + return Status::OK(); + } + + Status OnWorkFinishedLocked() override { + reader_.reset(nullptr); + file_.reset(nullptr); + return Status::OK(); + } + + Status ReadLocked(string* key, string* value, bool* produced, + bool* at_end) override { + *key = strings::StrCat(current_work(), ":", offset_); + Status status = reader_->ReadRecord(&offset_, value); + if (errors::IsOutOfRange(status)) { + *at_end = true; + return Status::OK(); + } + if (!status.ok()) return status; + *produced = true; + return Status::OK(); + } + + Status ResetLocked() override { + offset_ = 0; + reader_.reset(nullptr); + file_.reset(nullptr); + return ReaderBase::ResetLocked(); + } + + // TODO(josh11b): Implement serializing and restoring the state. + + private: + Env* const env_; + uint64 offset_; + std::unique_ptr<RandomAccessFile> file_; + std::unique_ptr<io::RecordReader> reader_; +}; + +class TFRecordReaderOp : public ReaderOpKernel { + public: + explicit TFRecordReaderOp(OpKernelConstruction* context) + : ReaderOpKernel(context) { + Env* env = context->env(); + SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); }); + } +}; + +REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU), + TFRecordReaderOp); + +} // namespace tensorflow |