about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/text_line_reader_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/text_line_reader_op.cc')
-rw-r--r-- tensorflow/core/kernels/text_line_reader_op.cc | 99
1 files changed, 99 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc
new file mode 100644
index 0000000000..51e4d6a2b8
--- /dev/null
+++ b/tensorflow/core/kernels/text_line_reader_op.cc
@@ -0,0 +1,99 @@
+// See docs in ../ops/io_ops.cc.
+
+#include <memory>
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/kernels/reader_base.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+
+// Reader that yields one record per line of a text file, keyed by
+// "<work>:<line number>". The first skip_header_lines lines are skipped.
+class TextLineReader : public ReaderBase {
+ public:
+  TextLineReader(const string& node_name, int skip_header_lines, Env* env)
+      : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")),
+        skip_header_lines_(skip_header_lines), env_(env), line_number_(0) {}
+
+  // Opens the current work item and consumes its leading header lines.
+  Status OnWorkStartedLocked() override {
+    line_number_ = 0;
+    RandomAccessFile* raw_file = nullptr;
+    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &raw_file));
+    input_buffer_.reset(new io::InputBuffer(raw_file, kBufferSize));
+    while (line_number_ < skip_header_lines_) {
+      string header_line;
+      const Status read_status = input_buffer_->ReadLine(&header_line);
+      // EOF while skipping headers is tolerated; the file yields no records.
+      if (errors::IsOutOfRange(read_status)) return Status::OK();
+      if (!read_status.ok()) return read_status;
+      ++line_number_;
+    }
+    return Status::OK();
+  }
+
+  // Drops the input buffer once the current work item is done.
+  Status OnWorkFinishedLocked() override {
+    input_buffer_.reset();
+    return Status::OK();
+  }
+
+  // Reads one line into *value; sets *key and *produced, or *at_end at EOF.
+  Status ReadLocked(string* key, string* value, bool* produced,
+                    bool* at_end) override {
+    const Status read_status = input_buffer_->ReadLine(value);
+    ++line_number_;
+    if (errors::IsOutOfRange(read_status)) {
+      *at_end = true;  // End of file, advance to the next.
+      return Status::OK();
+    }
+    if (read_status.ok()) {
+      *produced = true;
+      *key = strings::StrCat(current_work(), ":", line_number_);
+    }
+    return read_status;  // OK, or some other reading error.
+  }
+
+  // Clears per-file state before delegating to ReaderBase::ResetLocked().
+  Status ResetLocked() override {
+    input_buffer_.reset();
+    line_number_ = 0;
+    return ReaderBase::ResetLocked();
+  }
+
+  // TODO(josh11b): Implement serializing and restoring the state. Need
+  // to create TextLineReaderState proto to store ReaderBaseState,
+  // line_number_, and input_buffer_->Tell().
+
+ private:
+  enum { kBufferSize = 256 << 10 /* 256 kB */ };
+  const int skip_header_lines_;
+  Env* const env_;
+  int64 line_number_;
+  std::unique_ptr<io::InputBuffer> input_buffer_;
+};
+
+// Kernel for "TextLineReader": checks the attr, then installs the factory.
+class TextLineReaderOp : public ReaderOpKernel {
+ public:
+  explicit TextLineReaderOp(OpKernelConstruction* context)
+      : ReaderOpKernel(context) {
+    int n_headers = -1;
+    OP_REQUIRES_OK(context, context->GetAttr("skip_header_lines", &n_headers));
+    OP_REQUIRES(context, n_headers >= 0,
+                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
+                                        n_headers));
+    Env* env = context->env();
+    SetReaderFactory([this, n_headers, env]() {
+      return new TextLineReader(name(), n_headers, env);
+    });
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
+ TextLineReaderOp);  // Registers TextLineReaderOp for "TextLineReader" on CPU.
+
+} // namespace tensorflow