aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/tf_record_reader_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/tf_record_reader_op.cc')
-rw-r--r--tensorflow/core/kernels/tf_record_reader_op.cc23
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc
index 9604d44567..3cbb334141 100644
--- a/tensorflow/core/kernels/tf_record_reader_op.cc
+++ b/tensorflow/core/kernels/tf_record_reader_op.cc
@@ -27,17 +27,25 @@ namespace tensorflow {
class TFRecordReader : public ReaderBase {
public:
- TFRecordReader(const string& node_name, Env* env)
+ TFRecordReader(const string& node_name, const string& compression_type,
+ Env* env)
: ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")),
env_(env),
- offset_(0) {}
+ offset_(0),
+ compression_type_(compression_type) {}
Status OnWorkStartedLocked() override {
offset_ = 0;
RandomAccessFile* file = nullptr;
TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file));
file_.reset(file);
- reader_.reset(new io::RecordReader(file));
+
+ io::RecordReaderOptions options;
+ if (compression_type_ == "ZLIB") {
+ options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION;
+ }
+
+ reader_.reset(new io::RecordReader(file, options));
return Status::OK();
}
@@ -74,6 +82,7 @@ class TFRecordReader : public ReaderBase {
uint64 offset_;
std::unique_ptr<RandomAccessFile> file_;
std::unique_ptr<io::RecordReader> reader_;
+ string compression_type_ = "";
};
class TFRecordReaderOp : public ReaderOpKernel {
@@ -81,7 +90,13 @@ class TFRecordReaderOp : public ReaderOpKernel {
explicit TFRecordReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
Env* env = context->env();
- SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
+
+ string compression_type;
+ context->GetAttr("compression_type", &compression_type);
+
+ SetReaderFactory([this, compression_type, env]() {
+ return new TFRecordReader(name(), compression_type, env);
+ });
}
};