1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
// See docs in ../ops/io_ops.cc.
#include <memory>
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/kernels/reader_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/public/env.h"
namespace tensorflow {
// A ReaderBase that splits each input file into fixed-size records.
//
// Each file is treated as:
//   [header_bytes_ bytes][record_bytes_ bytes]...[record_bytes_ bytes]
//   [footer_bytes_ bytes]
// The header is skipped, the footer is never returned, and every complete
// record in between is produced with key "<filename>:<record index>".
class FixedLengthRecordReader : public ReaderBase {
 public:
  FixedLengthRecordReader(const string& node_name, int64 header_bytes,
                          int64 record_bytes, int64 footer_bytes, Env* env)
      : ReaderBase(
            strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
        header_bytes_(header_bytes),
        record_bytes_(record_bytes),
        footer_bytes_(footer_bytes),
        env_(env),
        file_pos_limit_(-1),
        record_number_(0) {}

  // Opens current_work() and positions the buffer just past the header.
  //
  // On success:
  //  * input_buffer_ != nullptr,
  //  * input_buffer_->Tell() == header_bytes_
  //  * file_pos_limit_ == file size - footer_bytes_
  Status OnWorkStartedLocked() override {
    record_number_ = 0;
    uint64 file_size = 0;
    TF_RETURN_IF_ERROR(env_->GetFileSize(current_work(), &file_size));
    // Subtract in signed arithmetic. `uint64 - int64` would convert the
    // footer size to unsigned, wrap around for files shorter than the footer,
    // and then narrow the huge result back into an int64.
    file_pos_limit_ = static_cast<int64>(file_size) - footer_bytes_;
    RandomAccessFile* file = nullptr;
    TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
    // NOTE(review): ownership of `file` is assumed to pass to io::InputBuffer;
    // confirm against io/inputbuffer.h.
    input_buffer_.reset(new io::InputBuffer(file, kBufferSize));
    TF_RETURN_IF_ERROR(input_buffer_->SkipNBytes(header_bytes_));
    return Status::OK();
  }

  // Releases the buffer for the file that was just finished.
  Status OnWorkFinishedLocked() override {
    input_buffer_.reset();
    return Status::OK();
  }

  // Produces the next complete record of the current file.
  //
  // Sets *at_end once no full record fits between the current position and
  // the start of the footer. A trailing fragment shorter than record_bytes_
  // is not emitted. Reading it would either pull footer bytes into the
  // record or run past the end of the file.
  Status ReadLocked(string* key, string* value, bool* produced,
                    bool* at_end) override {
    const int64 pos = input_buffer_->Tell();
    // The first test is the plain end-of-data check. The second rejects a
    // final record that would extend into the footer or past EOF.
    if (pos >= file_pos_limit_ || pos + record_bytes_ > file_pos_limit_) {
      *at_end = true;
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value));
    *key = strings::StrCat(current_work(), ":", record_number_);
    *produced = true;
    ++record_number_;
    return Status::OK();
  }

  // Drops all per-file state, then defers to ReaderBase::ResetLocked().
  Status ResetLocked() override {
    file_pos_limit_ = -1;
    record_number_ = 0;
    input_buffer_.reset();
    return ReaderBase::ResetLocked();
  }

  // TODO(josh11b): Implement serializing and restoring the state.

 private:
  enum { kBufferSize = 256 << 10 /* 256 kB */ };

  // Byte counts from the op attrs (FixedLengthRecordReaderOp rejects
  // negative values).
  const int64 header_bytes_;
  const int64 record_bytes_;
  const int64 footer_bytes_;
  Env* const env_;
  // Offset at which the footer begins in the current file; -1 when no file
  // is open.
  int64 file_pos_limit_;
  // Index of the next record within the current file, used in the key.
  int64 record_number_;
  std::unique_ptr<io::InputBuffer> input_buffer_;
};
// Kernel for the FixedLengthRecordReader op. It reads and validates the
// header/record/footer byte-count attrs, then installs a factory that builds
// FixedLengthRecordReader instances with those values.
class FixedLengthRecordReaderOp : public ReaderOpKernel {
 public:
  explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1;
    OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
    OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
    OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
    OP_REQUIRES(context, header_bytes >= 0,
                errors::InvalidArgument("header_bytes must be >= 0 not ",
                                        header_bytes));
    // A zero-byte record would never advance the read position, so the
    // reader would keep producing empty records for the same offset forever.
    OP_REQUIRES(context, record_bytes > 0,
                errors::InvalidArgument("record_bytes must be > 0 not ",
                                        record_bytes));
    OP_REQUIRES(context, footer_bytes >= 0,
                errors::InvalidArgument("footer_bytes must be >= 0 not ",
                                        footer_bytes));
    Env* env = context->env();
    SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() {
      return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
                                         footer_bytes, env);
    });
  }
};

// Registers FixedLengthRecordReaderOp as the CPU kernel for the
// "FixedLengthRecordReader" op.
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReader").Device(DEVICE_CPU),
                        FixedLengthRecordReaderOp);
} // namespace tensorflow
|