-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc                  | 542
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD                  |   1
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 292
-rw-r--r--  tensorflow/core/lib/strings/numbers.cc                             |  26
-rw-r--r--  tensorflow/core/lib/strings/numbers.h                              |   2
5 files changed, 203 insertions, 660 deletions
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index e88ad3dc32..97cc0bc6c9 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
namespace tensorflow {
@@ -102,11 +103,12 @@ class CSVDatasetOp : public DatasetOpKernel {
OP_REQUIRES(
ctx, select_cols.empty() || select_cols.front() >= 0,
errors::InvalidArgument("select_cols should be non-negative indices"));
+ bool select_all_cols = select_cols.empty();
- *output = new Dataset(ctx, std::move(filenames), header, buffer_size,
- output_types_, output_shapes_,
- std::move(record_defaults), std::move(select_cols),
- use_quote_delim, delim[0], std::move(na_value));
+ *output = new Dataset(
+ ctx, std::move(filenames), header, buffer_size, output_types_,
+ output_shapes_, std::move(record_defaults), std::move(select_cols),
+ select_all_cols, use_quote_delim, delim[0], std::move(na_value));
}
private:
@@ -116,7 +118,8 @@ class CSVDatasetOp : public DatasetOpKernel {
int64 buffer_size, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
- bool use_quote_delim, char delim, string na_value)
+ bool select_all_cols, bool use_quote_delim, char delim,
+ string na_value)
: GraphDatasetBase(ctx),
filenames_(std::move(filenames)),
header_(header),
@@ -125,6 +128,7 @@ class CSVDatasetOp : public DatasetOpKernel {
output_shapes_(output_shapes),
record_defaults_(std::move(record_defaults)),
select_cols_(std::move(select_cols)),
+ select_all_cols_(select_all_cols),
use_quote_delim_(use_quote_delim),
delim_(delim),
na_value_(std::move(na_value)) {}
@@ -162,24 +166,11 @@ class CSVDatasetOp : public DatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- bool select_all = dataset()->select_cols_.empty();
do {
// We are currently processing a file, so try to read the next record
- if (input_stream_) {
- Status s = ReadRecord(ctx, out_tensors, select_all,
- dataset()->select_cols_);
- if (s.ok()) {
- // Validate output
- if (out_tensors->size() != dataset()->out_type_.size()) {
- return errors::InvalidArgument(
- "Expect ", dataset()->out_type_.size(), " fields but have ",
- out_tensors->size(), " in record");
- }
-
- *end_of_sequence = false;
- return s;
- }
- if (!errors::IsOutOfRange(s)) {
+ if (buffered_input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors);
+ if (s.ok() || !errors::IsOutOfRange(s)) {
// Not at the end of file, return OK or non-EOF errors to caller.
*end_of_sequence = false;
return s;
@@ -212,341 +203,145 @@ class CSVDatasetOp : public DatasetOpKernel {
}
private:
- // Reads an entire CSV row from the input stream, either from the
- // existing buffer or by filling the buffer as needed. Converts extracted
+ // Reads a record by parsing the input buffer, converting extracted
// fields to output tensors as we go.
- //
- // When this function is called, pos_ should be the index of the first
- // character of the record in buffer_, or past the end of the buffer.
- // Note: ctx and out_tensors are only used in this function
- // when fields are included in the record.
- Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
- bool select_all, const std::vector<int64>& selected)
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- // At the end of the file, this will return errors::OutOfRange
- TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
- pos_ = 0;
- }
-
- // The first character may be \n if this is the continuation of a
- // \r\n linebreak between this and the previous record. If so, skip it.
-
- bool end_of_record = false; // Keep track of when we find \n, \r or EOF
- size_t num_parsed = 0;
- size_t num_selected_parsed = 0;
-
- Status result = Status::OK();
-
- while (!end_of_record) { // Read till we reach \n, \r or EOF
- bool include =
- select_all || (num_selected_parsed < selected.size() &&
- selected[num_selected_parsed] == num_parsed);
-
- // Don't fail fast, so that the next call to GetNext may still return
- // a valid record
- result.Update(
- ParseOneField(ctx, out_tensors, &end_of_record, include));
-
- num_parsed++;
- if (include) num_selected_parsed++;
- }
-
- return result;
- }
-
- // Parses one field from position pos_ in the buffer. Fields are
- // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
- // the next field.
- Status ParseOneField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- // If we get here, this means the previous field's end coincided
- // with the end of the buffer. We can fill the buffer without abandon.
- Status s = FillBuffer(&buffer_);
-
- if (errors::IsOutOfRange(s)) {
- // Reached EOF, and last field is empty
- *end_of_record = true;
- if (include) {
- return FieldToOutput(ctx, StringPiece(), out_tensors);
- } else {
- return Status::OK();
- }
- } else if (!s.ok()) {
- return s; // Surface other errors back to caller
+ // Extracts fields from one or more lines read from the buffered input stream.
+ out_tensors->reserve(dataset()->record_defaults_.size());
+
+ string input;
+ TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input));
+
+ size_t current_idx = 0;
+ size_t num_fields_parsed = 0;
+ size_t selector_idx = 0; // Keep track of index into select_cols
+
+ while (current_idx < input.size()) {
+ // In each iteration, parse one field
+ if (input[current_idx] == '\n' || input[current_idx] == '\r') {
+ // This should never happen, because buffered input reader splits
+ // input on newlines.
+ return errors::InvalidArgument("Parsing error.");
}
- pos_ = 0;
- }
-
- if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
- return ParseQuotedField(ctx, out_tensors, end_of_record, include);
- }
-
- return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
- }
-
- // For keeping track of relevant parts of a field from a previous buffer
- struct Piece {
- size_t start;
- size_t len;
- string buffer;
-
- Piece(string buffer, size_t start, size_t len)
- : start(start), len(len), buffer(std::move(buffer)) {}
- };
-
- // Given that pos_ exceeds the buffer, saves the relevant part of the
- // current buffer (if necessary), fills the buffer, and resets indices to
- // 0.
- Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
- size_t* start, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- string temp_buffer;
-
- buffer_.swap(temp_buffer);
- if (include && pos_ > *start) {
- earlier_pieces->push_back(
- Piece(std::move(temp_buffer), *start, pos_ - *start));
- }
- pos_ = 0;
- *start = 0;
- return FillBuffer(&buffer_);
- }
+ bool quoted = false;
+ bool include =
+ (dataset()->select_all_cols_ ||
+ dataset()->select_cols_[selector_idx] == num_fields_parsed);
- // Parses unquoted field from position pos_ in the buffer. Continually
- // reads from buffer until end of field is reached (delim, CRLF, or EOF).
- // Advances pos_ to keep track of our position in the buffer as we go,
- // stopping at the first character of the next field.
- Status ParseQuotedField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- std::vector<Piece> earlier_pieces;
- size_t start = pos_;
- pos_++; // Starting quotation mark
-
- while (true) { // Each iter reads 1 char, filling buffer if necessary
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- if (errors::IsOutOfRange(s)) {
- return errors::InvalidArgument(
- "Reached end of file without closing quoted field in "
- "record");
- } else if (!s.ok()) {
- return s; // Surface all other errors to caller
- }
+ if (dataset()->use_quote_delim_ && input[current_idx] == '"') {
+ quoted = true;
+ current_idx++;
}
- char ch = buffer_[pos_];
- if (ch == '"') {
- // When we encounter a quote, we look ahead to the next character to
- // decide what to do
- pos_++;
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- if (errors::IsOutOfRange(s)) {
- // This was the last field. We are done
- *end_of_record = true;
- return QuotedFieldToOutput(ctx, StringPiece(), out_tensors,
- earlier_pieces, include);
- } else if (!s.ok()) {
- return s;
+ // Parse the body of the field
+ string field;
+ if (!quoted) {
+ while (current_idx < input.size() &&
+ input[current_idx] != dataset()->delim_) {
+ if ((dataset()->use_quote_delim_ && input[current_idx] == '"') ||
+ input[current_idx] == '\n' || input[current_idx] == '\r') {
+ return errors::InvalidArgument(
+ "Unquoted fields cannot have quotes/CRLFs inside");
}
- }
-
- char next = buffer_[pos_];
- pos_++;
- if (next == dataset()->delim_) {
- return QuotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include);
-
- } else if (next == '\n' || next == '\r') {
- *end_of_record = true;
- Status s = QuotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include);
- if (next == '\r') SkipNewLineIfNecessary();
- return s;
- } else if (next != '"') {
- return errors::InvalidArgument(
- "Quote inside a string has to be escaped by another quote");
- }
+ if (include) field += input[current_idx];
+ current_idx++;
+ } // Exit condition: end of input, or current index at delim
+ // Go to next field or the end
+ current_idx++;
} else {
- pos_++;
- }
- }
- }
-
- // Converts quoted field to an output tensor, removing the starting
- // and ending quotes from it and unescaping double quotations if
- // necessary.
- Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
- std::vector<Tensor>* out_tensors,
- const std::vector<Piece>& earlier_pieces,
- bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!include) return Status::OK();
-
- if (earlier_pieces.empty()) {
- if (field.find('\"', 1) == field.size() - 1) {
- // `field` contains no escaped quotation marks.
- // Exclude framing quotation marks
- field.remove_prefix(1);
- field.remove_suffix(1);
- return FieldToOutput(ctx, field, out_tensors);
- }
- }
- string field_complete;
- size_t str_len = field.size();
- for (const Piece& p : earlier_pieces) {
- str_len += p.len;
- }
- field_complete.reserve(str_len);
-
- // This bool flips every time we see a quote, so that we skip the second
- // quote of every pair of adjacent quotes in the field. We need to track
- // this across iterations of the for loop because adjacent double quotes
- // may be in different buffers. Initialize to true because we also skip
- // the opening quotation mark of the quoted field.
- bool skip_next_quote = true;
- for (const Piece& p : earlier_pieces) {
- AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
- &field_complete, &skip_next_quote);
- }
- AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
- StringPiece result = StringPiece(field_complete);
- result.remove_suffix(1); // Skip final quote
-
- return FieldToOutput(ctx, result, out_tensors);
- }
-
- void AppendUnescapedPiece(StringPiece piece, string* field_complete,
- bool* skip_next_quote) {
- size_t from = 0;
- size_t found = piece.find('\"', from);
- while (found != string::npos) {
- if (!*skip_next_quote) {
- // This is the first quote in a pair of adjacent double quotes
- field_complete->append(piece.data() + from, found + 1 - from);
- }
- *skip_next_quote = !*skip_next_quote;
- from = found + 1;
- found = piece.find('\"', from);
- }
- // Include the chunk after the last quotation mark in the string
- if (from < piece.size()) {
- field_complete->append(piece.data() + from, piece.size() - from);
- }
- }
-
- // Parses unquoted field from position pos_ in the buffer. Continually
- // reads from buffer until end of field is reached (delim, CRLF, or EOF).
- // Advances pos_ to keep track of our position in the buffer as we go,
- // stopping at the first character of the next field.
- Status ParseUnquotedField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- std::vector<Piece> earlier_pieces;
- size_t start = pos_;
- while (true) { // Each iter reads 1 char, filling buffer if necessary
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- // Handle errors
- if (errors::IsOutOfRange(s)) {
- // Whatever we have is the last field of the last record
- *end_of_record = true;
- return UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
- } else if (!s.ok()) {
- return s; // Surface all other errors to caller
+ // A quoted field must end with '"' followed by delim, CRLF, or EOF
+ while (true) {
+ if (current_idx >= input.size() - 1 || input.empty()) {
+ if (current_idx == input.size() - 1 &&
+ input[current_idx] == '"') {
+ // We're at the end of the input, and the quote terminates the
+ // record. Go to end.
+ current_idx++;
+ break;
+ }
+ // If there's no terminating quote, it means our buffered record
+ // line reader split a record up. This can happen if there is a
+ // newline encased in quotes. The next line is also part of the
+ // record, so we read it and reset the index.
+ if (include && current_idx == input.size() - 1) {
+ // TODO(rachelim): Instead of building up a string, keep track
+ // of terminal indices (or starting char* and length)
+ // Also look into using /lib/strings/Scanner
+ field += input[current_idx];
+ }
+ if (include) {
+ field += '\n';
+ }
+ current_idx = 0;
+ Status s = buffered_input_stream_->ReadLine(&input);
+ if (!s.ok()) {
+ return errors::InvalidArgument(
+ "Quoted field has to end with quote followed by delim, "
+ "CRLF, or EOF");
+ }
+ } else if (input[current_idx] == '"' &&
+ input[current_idx + 1] == dataset()->delim_) {
+ // End of field, go to next field or end
+ current_idx += 2;
+ break;
+ } else if (input[current_idx] == '"') {
+ // Current char is a quote. Since we're not at end of field,
+ // the next character must also be a quote.
+ if (input[current_idx + 1] != '"') {
+ return errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another "
+ "quote");
+ }
+ if (include) field += '"';
+ current_idx += 2;
+ } else {
+ if (include) field += input[current_idx];
+ current_idx++;
+ }
}
}
- char ch = buffer_[pos_];
+ num_fields_parsed++;
- if (ch == dataset()->delim_) {
- Status s = UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
- pos_++;
- return s;
- }
- if (ch == '\n' || ch == '\r') {
- // need special case to skip over first \n of record if the line
- // breaks are \r\n
- Status s = UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include);
- *end_of_record = true;
- pos_++;
- if (ch == '\r') SkipNewLineIfNecessary();
- return s;
- }
- if (dataset()->use_quote_delim_ && ch == '"') {
- // Advance pos_ to the next field anyway so that we can ignore
- // errors gracefully if required. The caller of this will be able to
- // call ParseOneField and continue with the rest of the record.
- AdvanceToNextField(end_of_record);
- return errors::InvalidArgument(
- "Unquoted fields cannot have quotes inside");
+ if (include) {
+ // Add the tensor to the result
+ TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field),
+ selector_idx, out_tensors));
+ selector_idx++;
+ // Terminate early if we have all the fields we want
+ if (selector_idx == dataset()->select_cols_.size())
+ return Status::OK();
}
- // Otherwise, go to next character
- pos_++;
+ } // Exit condition: current_idx has reached the end of record
+
+ // Check if the last field is empty, and include it if necessary
+ bool include =
+ (dataset()->select_all_cols_ ||
+ dataset()->select_cols_[selector_idx] == num_fields_parsed);
+ if (include && !input.empty() &&
+ input[input.size() - 1] == dataset()->delim_) {
+ TF_RETURN_IF_ERROR(
+ FieldToOutput(ctx, string(), selector_idx, out_tensors));
}
- }
-
- // Advances pos_ to the start of the next field, as delimited by delim,
- // CRLF, or EOF, ignoring errors, and not keeping track of characters in
- // the current field.
- void AdvanceToNextField(bool* end_of_record)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- while (true) {
- if (pos_ >= buffer_.size()) {
- Status s = FillBuffer(&buffer_);
- pos_ = 0;
- if (!s.ok()) {
- *end_of_record = true;
- return;
- }
- }
- char ch = buffer_[pos_];
- pos_++;
-
- if (ch == dataset()->delim_) {
- return;
- }
-
- if (ch == '\n' || ch == '\r') {
- *end_of_record = true;
- if (ch == '\r') SkipNewLineIfNecessary();
- return;
- }
- }
- }
-
- Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- result->clear();
- Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result);
-
- if (errors::IsOutOfRange(s) && !result->empty()) {
- // Ignore OutOfRange error when ReadNBytes read < N bytes.
- return Status::OK();
+ // Check that number of fields matches
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have ",
+ out_tensors->size(), " in record");
}
- return s;
+ return Status::OK();
}
- // Given a field, converts it to the right output tensor type
- Status FieldToOutput(IteratorContext* ctx, StringPiece field,
+ // Given a string field, and its index in the output,
+ // converts it to a Tensor of the right type and adds it to the
+ // out_tensors vector.
+ Status FieldToOutput(IteratorContext* ctx, string field,
+ size_t output_idx,
std::vector<Tensor>* out_tensors) {
- size_t output_idx = out_tensors->size();
if (output_idx >= dataset()->out_type_.size()) {
// We can get here if we're selecting all columns, but the number of
// fields exceeds the number of defaults provided
@@ -602,7 +397,7 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->record_defaults_[output_idx].flat<float>()(0);
} else {
float value;
- if (!strings::safe_strtof(field, &value)) {
+ if (!strings::safe_strtof(field.c_str(), &value)) {
return errors::InvalidArgument(
"Field ", output_idx,
" in record is not a valid float: ", field);
@@ -617,7 +412,7 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->record_defaults_[output_idx].flat<double>()(0);
} else {
double value;
- if (!strings::safe_strtod(field, &value)) {
+ if (!strings::safe_strtod(field.c_str(), &value)) {
return errors::InvalidArgument(
"Field ", output_idx,
" in record is not a valid double: ", field);
@@ -631,7 +426,7 @@ class CSVDatasetOp : public DatasetOpKernel {
component.scalar<string>()() =
dataset()->record_defaults_[output_idx].flat<string>()(0);
} else {
- component.scalar<string>()() = field.ToString();
+ component.scalar<string>()() = std::move(field);
}
break;
}
@@ -644,50 +439,6 @@ class CSVDatasetOp : public DatasetOpKernel {
return Status::OK();
}
- // Records can be delimited by "\r\n" line breaks. When we encounter a
- // '\r', we have to check the next character to see if it is part of the
- // linebreak, and ignore it if so.
- void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- Status s = FillBuffer(&buffer_);
- pos_ = 0;
- // If we failed to fill buffer, it doesn't matter because we're done
- // with the record
- if (!s.ok()) return;
- }
- if (buffer_[pos_] == '\n') {
- pos_++;
- }
- }
-
- // Given a string field, and its index in the output,
- // converts it to a Tensor of the right type and adds it to the
- // out_tensors vector.
- Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
- std::vector<Tensor>* out_tensors,
- const std::vector<Piece>& earlier_pieces,
- bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!include) return Status::OK();
-
- if (earlier_pieces.empty()) {
- return FieldToOutput(ctx, field, out_tensors);
- }
-
- size_t str_len = field.size();
- for (const Piece& p : earlier_pieces) {
- str_len += p.len;
- }
- string field_complete;
- field_complete.reserve(str_len);
-
- for (const Piece& p : earlier_pieces) {
- field_complete.append(p.buffer, p.start, p.len);
- }
-
- field_complete.append(field.data(), field.size());
- return FieldToOutput(ctx, field_complete, out_tensors);
- }
-
// Sets up reader streams to read from the file at `current_file_index_`.
Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= dataset()->filenames_.size()) {
@@ -701,18 +452,16 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->filenames_[current_file_index_], &file_));
input_stream_.reset(
new io::RandomAccessInputStream(file_.get(), false));
- buffer_.clear();
- pos_ = 0;
+ // TODO(rachelim): Maintain our own buffer so we don't read every record
+ // twice
+ buffered_input_stream_.reset(new io::BufferedInputStream(
+ input_stream_.get(), dataset()->buffer_size_, false));
if (dataset()->header_) {
- // Read one line, but don't include it. Pass nullptrs as dummy
- // pointers to objects that shouldn't be invoked anyway
- // We need to process this as a record here instead of just finding
- // the first newline because it might contain quoted fields with
- // newlines in the header as well
- std::vector<int64> empty;
- Status s = ReadRecord(nullptr, nullptr, false, empty);
- if (!s.ok()) {
- return errors::InvalidArgument("Can't read header of file");
+ // Ignore header line
+ string str;
+ Status s = buffered_input_stream_->ReadLine(&str);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument("Can't read header of empty file");
}
}
return Status::OK();
@@ -721,15 +470,15 @@ class CSVDatasetOp : public DatasetOpKernel {
// Resets all reader streams.
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
input_stream_.reset();
+ buffered_input_stream_.reset();
file_.reset();
}
mutex mu_;
- string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
- size_t pos_ GUARDED_BY(
- mu_); // Index into the buffer must be maintained between iters
std::unique_ptr<io::RandomAccessInputStream> input_stream_
GUARDED_BY(mu_);
+ std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
+ GUARDED_BY(mu_);
size_t current_file_index_ GUARDED_BY(mu_) = 0;
std::unique_ptr<RandomAccessFile> file_
GUARDED_BY(mu_); // must outlive input_stream_
@@ -742,6 +491,7 @@ class CSVDatasetOp : public DatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const std::vector<Tensor> record_defaults_;
const std::vector<int64> select_cols_;
+ const bool select_all_cols_;
const bool use_quote_delim_;
const char delim_;
const string na_value_;
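
For reference, the new parsing strategy above can be summarized in a short,
self-contained Python sketch: each record starts from one ReadLine()-style
line, and a quoted field that the line reader split apart keeps consuming
lines until its closing quote appears. The names here (parse_record,
read_line) are illustrative only, not part of TensorFlow.

def parse_record(read_line, delim=',', use_quote_delim=True):
    """Parses one CSV record; read_line() returns the next line sans newline."""
    line = read_line()
    fields, idx = [], 0
    while idx < len(line):
        if use_quote_delim and line[idx] == '"':
            idx += 1
            chars = []
            while True:
                if idx >= len(line):
                    # The line reader split a quoted newline; restore it and
                    # continue this field on the next line.
                    chars.append('\n')
                    line, idx = read_line(), 0
                elif line[idx] == '"':
                    if idx + 1 < len(line) and line[idx + 1] == '"':
                        chars.append('"')  # "" is an escaped quote
                        idx += 2
                    else:
                        idx += 2  # closing quote; also step over the delim
                        break
                else:
                    chars.append(line[idx])
                    idx += 1
            fields.append(''.join(chars))
        else:
            end = line.find(delim, idx)
            if end == -1:
                end = len(line)
            fields.append(line[idx:end])
            idx = end + 1
    if line.endswith(delim):
        fields.append('')  # a trailing delimiter implies a final empty field
    return fields

lines = iter(['a,b,"""c""', '0","d', 'e"'])
print(parse_record(lambda: next(lines)))  # ['a', 'b', '"c"\n0', 'd\ne']
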
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 523d1f2f71..c483a43769 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -128,7 +128,6 @@ py_test(
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:error_ops",
"//tensorflow/contrib/data/python/ops:readers",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 74b90ec7d1..8c138c7081 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -25,7 +25,6 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers as core_readers
@@ -62,12 +61,12 @@ class CsvDatasetOpTest(test.TestCase):
op2 = sess.run(next2)
self.assertAllEqual(op1, op2)
- def setup_files(self, inputs, linebreak='\n'):
+ def setup_files(self, inputs):
filenames = []
for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
- with open(fn, 'wb') as f:
- f.write(linebreak.join(ip).encode('utf-8'))
+ fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i)
+ with open(fn, 'w') as f:
+ f.write('\n'.join(ip))
filenames.append(fn)
return filenames
@@ -87,47 +86,38 @@ class CsvDatasetOpTest(test.TestCase):
inputs, **kwargs)
self._assert_datasets_equal(g, dataset_actual, dataset_expected)
- def _verify_output_or_err(self,
- sess,
- dataset,
- expected_output=None,
- expected_err_re=None):
- nxt = dataset.make_one_shot_iterator().get_next()
- if expected_err_re is None:
- # Verify that output is expected, without errors
- expected_output = [[
- v.encode('utf-8') if isinstance(v, str) else v for v in op
- ] for op in expected_output]
- for value in expected_output:
- op = sess.run(nxt)
- self.assertAllEqual(op, value)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
- else:
- # Verify that OpError is produced as expected
- with self.assertRaisesOpError(expected_err_re):
- while True:
- try:
- sess.run(nxt)
- except errors.OutOfRangeError:
- break
-
def _test_dataset(self,
inputs,
expected_output=None,
expected_err_re=None,
- linebreak='\n',
**kwargs):
"""Checks that elements produced by CsvDataset match expected output."""
# Convert str type because py3 tf strings are bytestrings
- filenames = self.setup_files(inputs, linebreak)
+ filenames = self.setup_files(inputs)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(sess, dataset, expected_output,
- expected_err_re)
-
- def testCsvDataset_requiredFields(self):
+ nxt = dataset.make_one_shot_iterator().get_next()
+ if expected_err_re is None:
+ # Verify that output is expected, without errors
+ expected_output = [[
+ v.encode('utf-8') if isinstance(v, str) else v for v in op
+ ] for op in expected_output]
+ for value in expected_output:
+ op = sess.run(nxt)
+ self.assertAllEqual(op, value)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(nxt)
+ else:
+ # Verify that OpError is produced as expected
+ with self.assertRaisesOpError(expected_err_re):
+ while True:
+ try:
+ sess.run(nxt)
+ except errors.OutOfRangeError:
+ break
+
+ def testCsvDataset_floatRequired(self):
record_defaults = [[]] * 4
inputs = [['1,2,3,4']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@@ -147,36 +137,10 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
- def testCsvDataset_withEmptyFields(self):
- record_defaults = [[0]] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
- def testCsvDataset_errWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_dataset(
- inputs,
- expected_err_re='Unquoted fields cannot have quotes inside',
- record_defaults=record_defaults)
-
- def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']]
- filenames = self.setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
-
- def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, use_quote_delim=False)
+ def testCsvDataset_withQuoted(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
def testCsvDataset_mixedTypes(self):
record_defaults = [
@@ -200,6 +164,11 @@ class CsvDatasetOpTest(test.TestCase):
self._test_by_comparison(
inputs, record_defaults=record_defaults, field_delim=':')
+ def testCsvDataset_withEmptyValues(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,,3,4', ',6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
def testCsvDataset_withNaValue(self):
record_defaults = [[0]] * 4
inputs = [['1,NA,3,4', 'NA,6,7,8']]
@@ -207,8 +176,8 @@ class CsvDatasetOpTest(test.TestCase):
inputs, record_defaults=record_defaults, na_value='NA')
def testCsvDataset_withSelectCols(self):
- record_defaults = [['']] * 2
- inputs = [['1,2,3,4', '"5","6","7","8"']]
+ record_defaults = [[0]] * 2
+ inputs = [['1,2,3,4', '5,6,7,8']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, select_cols=[1, 2])
@@ -221,17 +190,27 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults=record_defaults,
select_cols=[3, 4])
- def testCsvDataset_withOneCol(self):
- record_defaults = [['NA']]
- inputs = [['0', '', '2']]
- self._test_dataset(
- inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
-
def testCsvDataset_withMultipleFiles(self):
record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
+ def testCsvDataset_withNewLine(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withMultipleNewLines(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
def testCsvDataset_withLeadingAndTrailingSpaces(self):
record_defaults = [[0.0]] * 4
inputs = [['0, 1, 2, 3']]
@@ -287,10 +266,9 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_errorWithHeaderEmptyFile(self):
record_defaults = [[0]] * 2
inputs = [[]]
- expected_err_re = "Can't read header of file"
self._test_dataset(
inputs,
- expected_err_re=expected_err_re,
+ expected_err_re="Can't read header of empty file",
record_defaults=record_defaults,
header=True,
)
@@ -306,7 +284,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['', '1,2']] # First record is empty
self._test_dataset(
inputs,
- expected_err_re='Expect 2 fields but have 1 in record',
+ expected_err_re='Expect 2 fields but have 0 in record',
record_defaults=record_defaults)
def testCsvDataset_withChainedOps(self):
@@ -323,7 +301,7 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
- record_defaults = [dtypes.float32, [0.0]]
+ record_defaults = [dtypes.float32, dtypes.float32]
inputs = [['1.0,2.0', '3.0,4.0']]
self._test_dataset(
inputs,
@@ -348,162 +326,6 @@ class CsvDatasetOpTest(test.TestCase):
self.assertEqual(result, sorted(result))
-## The following tests exercise parsing logic for quoted fields
-
- def testCsvDataset_withQuoted(self):
- record_defaults = [['']] * 4
- inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withOneColAndQuotes(self):
- record_defaults = [['']]
- inputs = [['"0"', '"1"', '"2"']]
- self._test_dataset(
- inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
-
- def testCsvDataset_withNewLine(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_withNewLineInUnselectedCol(self):
- record_defaults = [['']]
- inputs = [['1,"2\n3",4', '5,6,7']]
- self._test_dataset(
- inputs,
- expected_output=[['1'], ['5']],
- record_defaults=record_defaults,
- select_cols=[0])
-
- def testCsvDataset_withMultipleNewLines(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_errorWithTerminateMidRecord(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,"a']]
- self._test_dataset(
- inputs,
- expected_err_re=
- 'Reached end of file without closing quoted field in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withEscapedQuotes(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
-
-## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
-## and different types of line breaks
-
- def testCsvDataset_withInvalidBufferSize(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,d']]
- self._test_dataset(
- inputs,
- expected_err_re='buffer_size should be positive',
- record_defaults=record_defaults,
- buffer_size=0)
-
- def testCsvDataset_withBufferSize(self):
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs, expected, record_defaults=record_defaults, buffer_size=i + 1)
-
- def testCsvDataset_withCR(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r',
- record_defaults=record_defaults,
- buffer_size=i + 1)
-
- def testCsvDataset_withCRLF(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
-
- def testCsvDataset_withBufferSizeAndQuoted(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
- inputs, expected, linebreak='\n', record_defaults=record_defaults)
-
- def testCsvDataset_withCRAndQuoted(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
- inputs, expected, linebreak='\r', record_defaults=record_defaults)
-
- def testCsvDataset_withCRLFAndQuoted(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
- inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
-
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.
@@ -521,7 +343,7 @@ class CsvDatasetBenchmark(test.Benchmark):
self._filenames = []
for n in self._num_cols:
fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
- with open(fn, 'wb') as f:
+ with open(fn, 'w') as f:
# Just write 100 rows and use `repeat`... Assumes the cost
# of creating an iterator is not significant
row = ','.join([str_val for _ in range(n)])
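
The tests above drive the contrib API directly. A minimal usage sketch
consistent with the keyword arguments they pass (record_defaults,
select_cols, header) might look like this in TF 1.x graph mode; the file
path is purely illustrative.

import tensorflow as tf
from tensorflow.contrib.data.python.ops import readers

# Select two of the columns; record_defaults has one entry per selected
# column and also fixes each column's output dtype.
dataset = readers.CsvDataset(
    ['/tmp/example.csv'],            # hypothetical input file
    record_defaults=[[0.0], [0.0]],
    select_cols=[1, 2],
    header=True)

next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
        pass
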
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index f18c6dc709..987e4fe733 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -345,19 +345,6 @@ bool safe_strtof(const char* str, float* value) {
return processed_characters_count > 0;
}
-bool safe_strtof(StringPiece str, float* value) {
- int processed_characters_count = -1;
- auto len = str.size();
-
- // If string length exceeds buffer size or int max, fail.
- if (len >= kFastToBufferSize) return false;
- if (len > std::numeric_limits<int>::max()) return false;
-
- *value = StringToFloatConverter().StringToFloat(
- str.data(), static_cast<int>(len), &processed_characters_count);
- return processed_characters_count > 0;
-}
-
bool safe_strtod(const char* str, double* value) {
int processed_characters_count = -1;
auto len = str_util::Strnlen(str, kFastToBufferSize);
@@ -372,19 +359,6 @@ bool safe_strtod(const char* str, double* value) {
return processed_characters_count > 0;
}
-bool safe_strtod(StringPiece str, double* value) {
- int processed_characters_count = -1;
- auto len = str.size();
-
- // If string length exceeds buffer size or int max, fail.
- if (len >= kFastToBufferSize) return false;
- if (len > std::numeric_limits<int>::max()) return false;
-
- *value = StringToFloatConverter().StringToDouble(
- str.data(), static_cast<int>(len), &processed_characters_count);
- return processed_characters_count > 0;
-}
-
size_t FloatToBuffer(float value, char* buffer) {
// FLT_DIG is 6 for IEEE-754 floats, which are used on almost all
// platforms these days. Just in case some system exists where FLT_DIG
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index f62584dedb..9cb56415cb 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -116,14 +116,12 @@ bool safe_strtou64(StringPiece str, uint64* value);
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(str) >= kFastToBufferSize`.
bool safe_strtof(const char* str, float* value);
-bool safe_strtof(StringPiece str, float* value);
// Convert strings to double precision floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(str) >= kFastToBufferSize`.
bool safe_strtod(const char* str, double* value);
-bool safe_strtod(StringPiece str, double* value);
inline bool ProtoParseNumeric(StringPiece s, int32* value) {
return safe_strto32(s, value);