Diffstat (limited to 'tensorflow/contrib/data/kernels/csv_dataset_op.cc')
-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc | 153
1 file changed, 136 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 4657807785..f7e3ed886c 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -18,7 +18,10 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
namespace {
@@ -37,6 +40,10 @@ class CSVDatasetOp : public DatasetOpKernel {
ctx, filenames_tensor->dims() <= 1,
errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+ string compression_type;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
+ &compression_type));
+
OpInputList record_defaults_list;
OP_REQUIRES_OK(ctx,
ctx->input_list("record_defaults", &record_defaults_list));
@@ -86,6 +93,19 @@ class CSVDatasetOp : public DatasetOpKernel {
filenames.push_back(filenames_tensor->flat<string>()(i));
}
+ io::ZlibCompressionOptions zlib_compression_options =
+ io::ZlibCompressionOptions::DEFAULT();
+ if (compression_type == "ZLIB") {
+ zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type == "GZIP") {
+ zlib_compression_options = io::ZlibCompressionOptions::GZIP();
+ } else {
+ OP_REQUIRES(ctx, compression_type.empty(),
+ errors::InvalidArgument(
+ "Unsupported compression_type: ", compression_type, "."));
+ }
+ zlib_compression_options.input_buffer_size = buffer_size;
+
std::vector<int64> select_cols;
select_cols.reserve(select_cols_tensor->NumElements());
for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
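Both presets drive the same io::ZlibInputStream decoder; they differ only in framing, with GZIP() selecting gzip headers and trailers. A minimal sketch of the dispatch above as a free function, assuming the hypothetical helper name ZlibOptionsFor (not part of this patch):

    // Sketch only; `ZlibOptionsFor` is a hypothetical helper, not in the patch.
    Status ZlibOptionsFor(const string& compression_type, int64 buffer_size,
                          io::ZlibCompressionOptions* out) {
      *out = io::ZlibCompressionOptions::DEFAULT();  // "" or "ZLIB": raw zlib
      if (compression_type == "GZIP") {
        *out = io::ZlibCompressionOptions::GZIP();   // gzip header/trailer
      } else if (!compression_type.empty() && compression_type != "ZLIB") {
        return errors::InvalidArgument("Unsupported compression_type: ",
                                       compression_type, ".");
      }
      // The dataset's read buffer size doubles as the decompressor's input
      // buffer size.
      out->input_buffer_size = buffer_size;
      return Status::OK();
    }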
@@ -103,7 +123,8 @@ class CSVDatasetOp : public DatasetOpKernel {
ctx, select_cols.empty() || select_cols.front() >= 0,
errors::InvalidArgument("select_cols should be non-negative indices"));
- *output = new Dataset(ctx, std::move(filenames), header, buffer_size,
+ *output = new Dataset(ctx, std::move(filenames), header,
+ std::move(compression_type), zlib_compression_options,
output_types_, output_shapes_,
std::move(record_defaults), std::move(select_cols),
use_quote_delim, delim[0], std::move(na_value));
@@ -113,21 +134,24 @@ class CSVDatasetOp : public DatasetOpKernel {
class Dataset : public GraphDatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
- int64 buffer_size, const DataTypeVector& output_types,
+ string compression_type, io::ZlibCompressionOptions options,
+ const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
bool use_quote_delim, char delim, string na_value)
: GraphDatasetBase(ctx),
filenames_(std::move(filenames)),
header_(header),
- buffer_size_(buffer_size),
out_type_(output_types),
output_shapes_(output_shapes),
record_defaults_(std::move(record_defaults)),
select_cols_(std::move(select_cols)),
use_quote_delim_(use_quote_delim),
delim_(delim),
- na_value_(std::move(na_value)) {}
+ na_value_(std::move(na_value)),
+ use_compression_(!compression_type.empty()),
+ compression_type_(std::move(compression_type)),
+ options_(options) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
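One subtlety in the initializer list: members are initialized in declaration order, and `use_compression_` is declared before `compression_type_` (see the member additions at the end of this diff), so `!compression_type.empty()` is evaluated before the parameter is moved from. A hypothetical illustration of the bug this ordering avoids:

    // Hypothetical: with the declaration order reversed, the move would run
    // first and `used_` would be computed from a moved-from (empty) string.
    struct Holder {
      string type_;  // declared first => initialized first
      bool used_;
      Holder(string type)
          : type_(std::move(type)),  // moves the parameter...
            used_(!type.empty()) {}  // ...then reads it: always false
    };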
@@ -146,10 +170,45 @@ class CSVDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- // TODO(rachelim): Implement this
- std::vector<Node*> input_tensors;
- TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
- return errors::Unimplemented("CSVDataset: AsGraphDefInternal");
+ Node* filenames = nullptr;
+ Node* compression_type = nullptr;
+ Node* buffer_size = nullptr;
+ Node* header = nullptr;
+ Node* delim = nullptr;
+ Node* use_quote_delim = nullptr;
+ Node* na_value = nullptr;
+ Node* select_cols = nullptr;
+
+ std::vector<Node*> record_defaults;
+ record_defaults.reserve(record_defaults_.size());
+ for (const Tensor& t : record_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ record_defaults.emplace_back(node);
+ }
+
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(options_.input_buffer_size, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+ string delim_string(1, delim_);
+ TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+ TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+ std::make_pair(2, buffer_size), std::make_pair(3, header),
+ std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+ std::make_pair(6, na_value),
+ std::make_pair(7, select_cols)}, // Single tensor inputs
+ {std::make_pair(8, record_defaults)}, // Tensor list inputs
+ {}, output));
+ return Status::OK();
}
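Serialization only round-trips if the make_pair indices line up with the op's declared inputs. For reference, the matching CSVDataset registration is assumed to look roughly like this (abridged sketch, not part of this diff):

    REGISTER_OP("CSVDataset")
        .Input("filenames: string")              // index 0
        .Input("compression_type: string")       // index 1
        .Input("buffer_size: int64")             // index 2
        .Input("header: bool")                   // index 3
        .Input("field_delim: string")            // index 4
        .Input("use_quote_delim: bool")          // index 5
        .Input("na_value: string")               // index 6
        .Input("select_cols: int64")             // index 7
        .Input("record_defaults: output_types")  // index 8, the list input
        .Output("handle: variant");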
private:
@@ -201,14 +260,58 @@ class CSVDatasetOp : public DatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- // TODO(rachelim): Implement save
- return errors::Unimplemented("CSVDataset: SaveInternal");
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
+ // `input_stream_` is null if
+ // 1. GetNext has not been called even once.
+ // 2. All files have been read and the iterator has been exhausted.
+ // If num_buffer_reads_ == 0, the buffer has not been filled even once,
+ // so there is no stream position to save.
+ if (input_stream_ && num_buffer_reads_ > 0) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+ num_buffer_reads_));
+ }
+ return Status::OK();
}
+
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- // TODO(rachelim): Implement restore
- return errors::Unimplemented("CSVDataset: RestoreInternal");
+ ResetStreamsLocked();
+ int64 current_file_index;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
+ current_file_index_ = size_t(current_file_index);
+ // The keys "pos" and "num_buffer_reads" are written only if
+ // the iterator was saved with an open, partially read file.
+ if (reader->Contains(full_name("pos"))) {
+ int64 pos, num_buffer_reads;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+ &num_buffer_reads));
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+ num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+ // Restore the stream position to the start of the most recently
+ // read buffer. SkipNBytes may return OutOfRange if the file size is
+ // not an exact multiple of the buffer size and the final buffer read
+ // was < buffer_size; that case is valid, so only other errors are
+ // surfaced.
+ Status s = input_stream_->SkipNBytes(
+ num_buffer_reads_ * dataset()->options_.input_buffer_size);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+ return s;
+ }
+
+ Status s2 = FillBuffer(&buffer_);
+ if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+ return s2;
+ }
+ pos_ = size_t(pos);
+ }
+ return Status::OK();
}
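The `num_buffer_reads - 1` is deliberate: FillBuffer increments the counter before reading, so the restore path skips only the fully consumed buffers and lets the subsequent FillBuffer re-read the buffer that was live at save time. A worked example with assumed numbers:

    // Hypothetical numbers: buffer size B = 256, file = 600 bytes; saved
    // state had num_buffer_reads == 3 (live buffer held bytes [512, 600))
    // and pos == 25.
    //   num_buffer_reads_ = 3 - 1;   // = 2, the fully consumed buffers
    //   SkipNBytes(2 * 256);         // stream positioned at byte 512
    //   FillBuffer(&buffer_);        // re-reads bytes [512, 600); counter
    //                                // returns to 3; the short read's
    //                                // OutOfRange status is ignored
    //   pos_ = 25;                   // resume parsing inside that buffer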
private:
@@ -510,7 +613,9 @@ class CSVDatasetOp : public DatasetOpKernel {
Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
result->clear();
- Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result);
+ ++num_buffer_reads_;
+ Status s = input_stream_->ReadNBytes(
+ dataset()->options_.input_buffer_size, result);
if (errors::IsOutOfRange(s) && !result->empty()) {
// Ignore OutOfRange error when ReadNBytes read < N bytes.
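Bumping the counter before the read gives the simple position invariant that the restore logic above depends on, sketched here (illustrative only, not part of the patch):

    // After FillBuffer has run k times on an uncompressed stream of `size`
    // bytes, with B == dataset()->options_.input_buffer_size:
    //   input_stream_->Tell() == std::min<int64>(k * B, size)
    // RestoreInternal re-creates this state with SkipNBytes((k - 1) * B)
    // followed by one FillBuffer call.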
@@ -675,10 +780,20 @@ class CSVDatasetOp : public DatasetOpKernel {
// Actually move on to next file.
TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
dataset()->filenames_[current_file_index_], &file_));
- input_stream_.reset(
- new io::RandomAccessInputStream(file_.get(), false));
+ random_access_input_stream_ =
+ std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+ if (dataset()->use_compression_) {
+ input_stream_ = std::make_shared<io::ZlibInputStream>(
+ random_access_input_stream_.get(),
+ dataset()->options_.input_buffer_size,
+ dataset()->options_.input_buffer_size, dataset()->options_);
+ } else {
+ input_stream_ = random_access_input_stream_;
+ }
buffer_.clear();
pos_ = 0;
+ num_buffer_reads_ = 0;
if (dataset()->header_) {
// Read one line, but don't include it. Pass nullptrs as dummy
// pointers to objects that shouldn't be invoked anyway
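The layering keeps the parser compression-agnostic: everything downstream reads through io::InputStreamInterface, and the zlib stream simply wraps the positional stream. A standalone sketch of the same stack (the file name and buf_size are placeholders, and file must outlive both streams):

    std::unique_ptr<RandomAccessFile> file;
    TF_RETURN_IF_ERROR(env->NewRandomAccessFile("data.csv.gz", &file));
    auto base =
        std::make_shared<io::RandomAccessInputStream>(file.get(), false);
    std::shared_ptr<io::InputStreamInterface> in =
        std::make_shared<io::ZlibInputStream>(
            base.get(), buf_size, buf_size,  // input/output buffer bytes
            io::ZlibCompressionOptions::GZIP());
    string chunk;
    TF_RETURN_IF_ERROR(in->ReadNBytes(buf_size, &chunk));  // decompressed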
@@ -704,8 +819,10 @@ class CSVDatasetOp : public DatasetOpKernel {
string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
size_t pos_ GUARDED_BY(
mu_); // Index into the buffer must be maintained between iters
- std::unique_ptr<io::RandomAccessInputStream> input_stream_
+ size_t num_buffer_reads_ GUARDED_BY(mu_);
+ std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
GUARDED_BY(mu_);
+ std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
size_t current_file_index_ GUARDED_BY(mu_) = 0;
std::unique_ptr<RandomAccessFile> file_
GUARDED_BY(mu_); // must outlive input_stream_
@@ -713,7 +830,6 @@ class CSVDatasetOp : public DatasetOpKernel {
const std::vector<string> filenames_;
const bool header_;
- const int64 buffer_size_;
const DataTypeVector out_type_;
const std::vector<PartialTensorShape> output_shapes_;
const std::vector<Tensor> record_defaults_;
@@ -721,6 +837,9 @@ class CSVDatasetOp : public DatasetOpKernel {
const bool use_quote_delim_;
const char delim_;
const string na_value_;
+ const bool use_compression_;
+ const string compression_type_;
+ const io::ZlibCompressionOptions options_;
}; // class Dataset
DataTypeVector output_types_;