Diffstat (limited to 'tensorflow/contrib/data/kernels/csv_dataset_op.cc')
-rw-r--r-- | tensorflow/contrib/data/kernels/csv_dataset_op.cc | 153 |
1 file changed, 136 insertions, 17 deletions
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 4657807785..f7e3ed886c 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -18,7 +18,10 @@ limitations under the License.
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
 #include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
 
 namespace tensorflow {
 namespace {
@@ -37,6 +40,10 @@ class CSVDatasetOp : public DatasetOpKernel {
         ctx, filenames_tensor->dims() <= 1,
         errors::InvalidArgument("`filenames` must be a scalar or a vector."));
 
+    string compression_type;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
+                                                    &compression_type));
+
     OpInputList record_defaults_list;
     OP_REQUIRES_OK(ctx,
                    ctx->input_list("record_defaults", &record_defaults_list));
@@ -86,6 +93,19 @@ class CSVDatasetOp : public DatasetOpKernel {
       filenames.push_back(filenames_tensor->flat<string>()(i));
     }
 
+    io::ZlibCompressionOptions zlib_compression_options =
+        io::ZlibCompressionOptions::DEFAULT();
+    if (compression_type == "ZLIB") {
+      zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
+    } else if (compression_type == "GZIP") {
+      zlib_compression_options = io::ZlibCompressionOptions::GZIP();
+    } else {
+      OP_REQUIRES(ctx, compression_type.empty(),
+                  errors::InvalidArgument(
+                      "Unsupported compression_type: ", compression_type, "."));
+    }
+    zlib_compression_options.input_buffer_size = buffer_size;
+
     std::vector<int64> select_cols;
     select_cols.reserve(select_cols_tensor->NumElements());
     for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
@@ -103,7 +123,8 @@ class CSVDatasetOp : public DatasetOpKernel {
         ctx, select_cols.empty() || select_cols.front() >= 0,
         errors::InvalidArgument("select_cols should be non-negative indices"));
 
-    *output = new Dataset(ctx, std::move(filenames), header, buffer_size,
+    *output = new Dataset(ctx, std::move(filenames), header,
+                          std::move(compression_type), zlib_compression_options,
                           output_types_, output_shapes_,
                           std::move(record_defaults), std::move(select_cols),
                           use_quote_delim, delim[0], std::move(na_value));
@@ -113,21 +134,24 @@ class CSVDatasetOp : public DatasetOpKernel {
   class Dataset : public GraphDatasetBase {
    public:
     Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
-            int64 buffer_size, const DataTypeVector& output_types,
+            string compression_type, io::ZlibCompressionOptions options,
+            const DataTypeVector& output_types,
             const std::vector<PartialTensorShape>& output_shapes,
             std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
             bool use_quote_delim, char delim, string na_value)
         : GraphDatasetBase(ctx),
           filenames_(std::move(filenames)),
           header_(header),
-          buffer_size_(buffer_size),
           out_type_(output_types),
           output_shapes_(output_shapes),
           record_defaults_(std::move(record_defaults)),
           select_cols_(std::move(select_cols)),
           use_quote_delim_(use_quote_delim),
           delim_(delim),
-          na_value_(std::move(na_value)) {}
+          na_value_(std::move(na_value)),
+          use_compression_(!compression_type.empty()),
+          compression_type_(std::move(compression_type)),
+          options_(options) {}
 
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
@@ -146,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
    protected:
     Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
                               Node** output) const override {
-      // TODO(rachelim): Implement this
-      std::vector<Node*> input_tensors;
-      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
-      return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+      Node* filenames = nullptr;
+      Node* compression_type = nullptr;
+      Node* buffer_size = nullptr;
+      Node* header = nullptr;
+      Node* delim = nullptr;
+      Node* use_quote_delim = nullptr;
+      Node* na_value = nullptr;
+      Node* select_cols = nullptr;
+
+      std::vector<Node*> record_defaults;
+      record_defaults.reserve(record_defaults_.size());
+      for (const Tensor& t : record_defaults_) {
+        Node* node;
+        TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+        record_defaults.emplace_back(node);
+      }
+
+      TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+      TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+      TF_RETURN_IF_ERROR(
+          b->AddScalar(options_.input_buffer_size, &buffer_size));
+      TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+      string delim_string(1, delim_);
+      TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+      TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+      TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+      TF_RETURN_IF_ERROR(b->AddDataset(
+          this,
+          {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+           std::make_pair(2, buffer_size), std::make_pair(3, header),
+           std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+           std::make_pair(6, na_value),
+           std::make_pair(7, select_cols)},  // Single tensor inputs
+          {std::make_pair(8, record_defaults)},  // Tensor list inputs
+          {}, output));
+      return Status::OK();
     }
 
    private:
@@ -201,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
      protected:
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        // TODO(rachelim): Implement save
-        return errors::Unimplemented("CSVDataset: SaveInternal");
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                               current_file_index_));
+        // `input_stream_` is empty if
+        // 1. GetNext has not been called even once.
+        // 2. All files have been read and the iterator has been exhausted.
+        if (input_stream_ && num_buffer_reads_ > 0) {
+          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+          // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+          TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+                                                 num_buffer_reads_));
+        }
+        return Status::OK();
       }
+
       Status RestoreInternal(IteratorContext* ctx,
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        // TODO(rachelim): Implement restore
-        return errors::Unimplemented("CSVDataset: RestoreInternal");
+        ResetStreamsLocked();
+        int64 current_file_index;
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                              &current_file_index));
+        current_file_index_ = size_t(current_file_index);
+        // The keys "pos" and "num_buffer_reads" are written only if
+        // the iterator was saved with an open, partially read file.
+        if (reader->Contains(full_name("pos"))) {
+          int64 pos, num_buffer_reads;
+          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+          TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+                                                &num_buffer_reads));
+
+          TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+          num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+          // Restores the most recently held buffer
+          Status s = input_stream_->SkipNBytes(
+              num_buffer_reads_ * dataset()->options_.input_buffer_size);
+          if (!s.ok() && !errors::IsOutOfRange(s)) {
+            // We might get out of range error here if the size of the file
+            // is not an exact multiple of the buffer size, and the last buffer
+            // read is < buffer_size. This is valid and we do not surface the
+            // error.
+            return s;
+          }
+
+          Status s2 = FillBuffer(&buffer_);
+          if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+            return s2;
+          }
+          pos_ = size_t(pos);
+        }
+        return Status::OK();
       }
 
      private:
@@ -510,7 +613,9 @@ class CSVDatasetOp : public DatasetOpKernel {
 
       Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
         result->clear();
-        Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result);
+        ++num_buffer_reads_;
+        Status s = input_stream_->ReadNBytes(
+            dataset()->options_.input_buffer_size, result);
 
         if (errors::IsOutOfRange(s) && !result->empty()) {
           // Ignore OutOfRange error when ReadNBytes read < N bytes
@@ -675,10 +780,20 @@ class CSVDatasetOp : public DatasetOpKernel {
         // Actually move on to next file.
         TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
            dataset()->filenames_[current_file_index_], &file_));
-        input_stream_.reset(
-            new io::RandomAccessInputStream(file_.get(), false));
+        random_access_input_stream_ =
+            std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+        if (dataset()->use_compression_) {
+          input_stream_ = std::make_shared<io::ZlibInputStream>(
+              random_access_input_stream_.get(),
+              dataset()->options_.input_buffer_size,
+              dataset()->options_.input_buffer_size, dataset()->options_);
+        } else {
+          input_stream_ = random_access_input_stream_;
+        }
         buffer_.clear();
         pos_ = 0;
+        num_buffer_reads_ = 0;
         if (dataset()->header_) {
           // Read one line, but don't include it. Pass nullptrs as dummy
           // pointers to objects that shouldn't be invoked anyway
@@ -704,8 +819,10 @@ class CSVDatasetOp : public DatasetOpKernel {
       string buffer_ GUARDED_BY(mu_);  // Maintain our own buffer
       size_t pos_ GUARDED_BY(
           mu_);  // Index into the buffer must be maintained between iters
-      std::unique_ptr<io::RandomAccessInputStream> input_stream_
+      size_t num_buffer_reads_ GUARDED_BY(mu_);
+      std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
           GUARDED_BY(mu_);
+      std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
       size_t current_file_index_ GUARDED_BY(mu_) = 0;
       std::unique_ptr<RandomAccessFile> file_
           GUARDED_BY(mu_);  // must outlive input_stream_
@@ -713,7 +830,6 @@ class CSVDatasetOp : public DatasetOpKernel {
 
     const std::vector<string> filenames_;
     const bool header_;
-    const int64 buffer_size_;
     const DataTypeVector out_type_;
     const std::vector<PartialTensorShape> output_shapes_;
     const std::vector<Tensor> record_defaults_;
@@ -721,6 +837,9 @@ class CSVDatasetOp : public DatasetOpKernel {
     const bool use_quote_delim_;
     const char delim_;
     const string na_value_;
+    const bool use_compression_;
+    const string compression_type_;
+    const io::ZlibCompressionOptions options_;
   };  // class Dataset
 
   DataTypeVector output_types_;
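For context, the stream setup changed above (in SetupStreamsLocked) follows one pattern: open the file, wrap it in a RandomAccessInputStream, and, only when a compression type was supplied, wrap that again in a ZlibInputStream configured from the ZlibCompressionOptions. The sketch below restates that pattern outside the iterator; the helper name MakeCsvInputStream, its signature, and the surrounding boilerplate are illustrative assumptions, not part of this change.

// Minimal sketch of the stream-wrapping pattern used in SetupStreamsLocked.
// `MakeCsvInputStream` is a hypothetical helper, not part of the diff.
#include <memory>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Returns an input stream over `filename`; decompresses if requested.
// `file` and `raw_stream` are output parameters that must outlive the
// returned stream, mirroring the member lifetimes in the iterator.
std::shared_ptr<io::InputStreamInterface> MakeCsvInputStream(
    Env* env, const string& filename, bool use_compression,
    const io::ZlibCompressionOptions& options,
    std::unique_ptr<RandomAccessFile>* file,
    std::shared_ptr<io::RandomAccessInputStream>* raw_stream) {
  TF_CHECK_OK(env->NewRandomAccessFile(filename, file));
  // Raw byte stream over the file; it does not take ownership of the file.
  *raw_stream =
      std::make_shared<io::RandomAccessInputStream>(file->get(), false);
  if (use_compression) {
    // Decompress transparently; both buffers use options.input_buffer_size,
    // as in the diff above.
    return std::make_shared<io::ZlibInputStream>(
        raw_stream->get(), options.input_buffer_size,
        options.input_buffer_size, options);
  }
  return *raw_stream;
}

}  // namespace tensorflow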