1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
// See docs in ../ops/io_ops.cc.
#include <memory>
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/kernels/reader_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/env.h"
namespace tensorflow {
// Reader that emits one record per line of each file it is given.
//
// For each work item (a filename, obtained via current_work()), the first
// `skip_header_lines` lines are read and discarded. Every following line is
// produced as a value whose key is "<filename>:<line_number>". line_number
// starts at 1 and counts the skipped header lines too.
class TextLineReader : public ReaderBase {
 public:
  // node_name: embedded in this reader's name ("TextLineReader '<node_name>'").
  // skip_header_lines: number of leading lines dropped from every file.
  // env: used to open the work files; not owned by the reader and must stay
  //   valid for as long as the reader opens files.
  TextLineReader(const string& node_name, int skip_header_lines, Env* env)
      : ReaderBase(strings::StrCat("TextLineReader '", node_name, "'")),
        skip_header_lines_(skip_header_lines),
        env_(env),
        line_number_(0) {}

  // Opens the file named by current_work() and consumes its header lines.
  // Hitting end-of-file while skipping headers is not treated as an error.
  Status OnWorkStartedLocked() override {
    line_number_ = 0;
    RandomAccessFile* file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
    // NOTE(review): `file` is never deleted in this class; this presumably
    // relies on io::InputBuffer taking ownership of it -- confirm against
    // inputbuffer.h.
    input_buffer_.reset(new io::InputBuffer(file, kBufferSize));
    for (; line_number_ < skip_header_lines_; ++line_number_) {
      string line_contents;
      Status status = input_buffer_->ReadLine(&line_contents);
      if (errors::IsOutOfRange(status)) {
        // We ignore an end of file error when skipping header lines.
        // We will end up skipping this file.
        return Status::OK();
      }
      TF_RETURN_IF_ERROR(status);
    }
    return Status::OK();
  }

  // Releases the buffer (and the file it reads from) for the finished item.
  Status OnWorkFinishedLocked() override {
    input_buffer_.reset();
    return Status::OK();
  }

  // Reads the next line of the current file into *value.
  //   - Success: sets *key to "<filename>:<line_number>" and *produced = true.
  //   - End of file (OutOfRange): sets *at_end = true and returns OK.
  //   - Any other error: returned unchanged.
  Status ReadLocked(string* key, string* value, bool* produced,
                    bool* at_end) override {
    Status status = input_buffer_->ReadLine(value);
    if (status.ok()) {
      // Count only lines that were actually read, so a failed read does not
      // shift the line numbers (and hence keys) of later records.
      ++line_number_;
      *key = strings::StrCat(current_work(), ":", line_number_);
      *produced = true;
      return status;
    }
    if (errors::IsOutOfRange(status)) {  // End of file, advance to the next.
      *at_end = true;
      return Status::OK();
    }
    return status;  // Some other reading error.
  }

  // Drops per-file state, then lets ReaderBase reset its own state.
  Status ResetLocked() override {
    line_number_ = 0;
    input_buffer_.reset();
    return ReaderBase::ResetLocked();
  }

  // TODO(josh11b): Implement serializing and restoring the state. Need
  // to create TextLineReaderState proto to store ReaderBaseState,
  // line_number_, and input_buffer_->Tell().

 private:
  enum { kBufferSize = 256 << 10 /* 256 kB */ };

  // Number of leading lines to discard from every file.
  const int skip_header_lines_;
  // Used to open work files; not owned.
  Env* const env_;
  // 1-based number of the last line read from the current file (0 = none).
  int64 line_number_;
  // Buffered reader over the current file; null when no work is in progress.
  std::unique_ptr<io::InputBuffer> input_buffer_;
};
// Op kernel for the "TextLineReader" op. Its only job is to validate the
// op's attributes and tell ReaderOpKernel how to build a TextLineReader.
class TextLineReaderOp : public ReaderOpKernel {
 public:
  // Reads the "skip_header_lines" attr and rejects negative values. It then
  // installs a factory producing TextLineReaders named after this op and
  // bound to the Env available at construction time.
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int num_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &num_header_lines));
    const bool header_count_is_valid = num_header_lines >= 0;
    OP_REQUIRES(context, header_count_is_valid,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        num_header_lines));
    Env* const reader_env = context->env();
    SetReaderFactory([this, num_header_lines, reader_env]() {
      return new TextLineReader(this->name(), num_header_lines, reader_env);
    });
  }
};
// Registers TextLineReaderOp as the CPU kernel for the "TextLineReader" op.
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
TextLineReaderOp);
} // namespace tensorflow
|