diff options
Diffstat (limited to 'tensorflow/core/kernels/tf_record_reader_op.cc')
-rw-r--r-- | tensorflow/core/kernels/tf_record_reader_op.cc | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index 9604d44567..3cbb334141 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -27,17 +27,25 @@ namespace tensorflow { class TFRecordReader : public ReaderBase { public: - TFRecordReader(const string& node_name, Env* env) + TFRecordReader(const string& node_name, const string& compression_type, + Env* env) : ReaderBase(strings::StrCat("TFRecordReader '", node_name, "'")), env_(env), - offset_(0) {} + offset_(0), + compression_type_(compression_type) {} Status OnWorkStartedLocked() override { offset_ = 0; RandomAccessFile* file = nullptr; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file)); file_.reset(file); - reader_.reset(new io::RecordReader(file)); + + io::RecordReaderOptions options; + if (compression_type_ == "ZLIB") { + options.compression_type = io::RecordReaderOptions::ZLIB_COMPRESSION; + } + + reader_.reset(new io::RecordReader(file, options)); return Status::OK(); } @@ -74,6 +82,7 @@ class TFRecordReader : public ReaderBase { uint64 offset_; std::unique_ptr<RandomAccessFile> file_; std::unique_ptr<io::RecordReader> reader_; + string compression_type_ = ""; }; class TFRecordReaderOp : public ReaderOpKernel { @@ -81,7 +90,13 @@ class TFRecordReaderOp : public ReaderOpKernel { explicit TFRecordReaderOp(OpKernelConstruction* context) : ReaderOpKernel(context) { Env* env = context->env(); - SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); }); + + string compression_type; + context->GetAttr("compression_type", &compression_type); + + SetReaderFactory([this, compression_type, env]() { + return new TFRecordReader(name(), compression_type, env); + }); } }; |