diff options
-rw-r--r-- | tensorflow/contrib/data/kernels/csv_dataset_op.cc | 542 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 292 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers.cc | 26 | ||||
-rw-r--r-- | tensorflow/core/lib/strings/numbers.h | 2 |
5 files changed, 203 insertions, 660 deletions
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index e88ad3dc32..97cc0bc6c9 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -102,11 +103,12 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); + bool select_all_cols = select_cols.empty(); - *output = new Dataset(ctx, std::move(filenames), header, buffer_size, - output_types_, output_shapes_, - std::move(record_defaults), std::move(select_cols), - use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset( + ctx, std::move(filenames), header, buffer_size, output_types_, + output_shapes_, std::move(record_defaults), std::move(select_cols), + select_all_cols, use_quote_delim, delim[0], std::move(na_value)); } private: @@ -116,7 +118,8 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, std::vector<Tensor> record_defaults, std::vector<int64> select_cols, - bool use_quote_delim, char delim, string na_value) + bool select_all_cols, bool use_quote_delim, char delim, + string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -125,6 +128,7 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), + select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} @@ -162,24 +166,11 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (input_stream_) { - Status s = ReadRecord(ctx, out_tensors, select_all, - dataset()->select_cols_); - if (s.ok()) { - // Validate output - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument( - "Expect ", dataset()->out_type_.size(), " fields but have ", - out_tensors->size(), " in record"); - } - - *end_of_sequence = false; - return s; - } - if (!errors::IsOutOfRange(s)) { + if (buffered_input_stream_) { + Status s = ReadRecord(ctx, out_tensors); + if (s.ok() || !errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -212,341 +203,145 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads an entire CSV row from the input stream, either from the - // existing buffer or by filling the buffer as needed. Converts extracted + // Reads a record by parsing the input buffer, and converting extracted // fields to output tensors as we go. - // - // When this function is called, pos_ should be the index of the first - // character of the record in buffer_, or past the end of the buffer. - // Note: ctx and out_tensors are only used in this function - // when fields are included in the record. - Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors, - bool select_all, const std::vector<int64>& selected) + Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // At the end of the file, this will return errors::OutOfRange - TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); - pos_ = 0; - } - - // The first character may be \n if this is the continuation of a - // \r\n linebreak between this and the previous record. If so, skip it. - - bool end_of_record = false; // Keep track of when we find \n, \r or EOF - size_t num_parsed = 0; - size_t num_selected_parsed = 0; - - Status result = Status::OK(); - - while (!end_of_record) { // Read till we reach \n, \r or EOF - bool include = - select_all || (num_selected_parsed < selected.size() && - selected[num_selected_parsed] == num_parsed); - - // Don't fail fast, so that the next call to GetNext may still return - // a valid record - result.Update( - ParseOneField(ctx, out_tensors, &end_of_record, include)); - - num_parsed++; - if (include) num_selected_parsed++; - } - - return result; - } - - // Parses one field from position pos_ in the buffer. Fields are - // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of - // the next field. - Status ParseOneField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // If we get here, this means the previous field's end coincided - // with the end of the buffer. We can fill the buffer without abandon. - Status s = FillBuffer(&buffer_); - - if (errors::IsOutOfRange(s)) { - // Reached EOF, and last field is empty - *end_of_record = true; - if (include) { - return FieldToOutput(ctx, StringPiece(), out_tensors); - } else { - return Status::OK(); - } - } else if (!s.ok()) { - return s; // Surface other errors back to caller + // Extracts fields from line(s) from the buffered input stream. + out_tensors->reserve(dataset()->record_defaults_.size()); + + string input; + TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); + + size_t current_idx = 0; + size_t num_fields_parsed = 0; + size_t selector_idx = 0; // Keep track of index into select_cols + + while (current_idx < input.size()) { + // In each iteration, parse one field + if (input[current_idx] == '\n' || input[current_idx] == '\r') { + // This should never happen, because buffered input reader splits + // input on newlines. + return errors::InvalidArgument("Parsing error."); } - pos_ = 0; - } - - if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { - return ParseQuotedField(ctx, out_tensors, end_of_record, include); - } - - return ParseUnquotedField(ctx, out_tensors, end_of_record, include); - } - - // For keeping track of relevant parts of a field from a previous buffer - struct Piece { - size_t start; - size_t len; - string buffer; - - Piece(string buffer, size_t start, size_t len) - : start(start), len(len), buffer(std::move(buffer)) {} - }; - - // Given that pos_ exceeds the buffer, saves the relevant part of the - // current buffer (if necessary), fills the buffer, and resets indices to - // 0. - Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces, - size_t* start, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - string temp_buffer; - - buffer_.swap(temp_buffer); - if (include && pos_ > *start) { - earlier_pieces->push_back( - Piece(std::move(temp_buffer), *start, pos_ - *start)); - } - pos_ = 0; - *start = 0; - return FillBuffer(&buffer_); - } + bool quoted = false; + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseQuotedField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector<Piece> earlier_pieces; - size_t start = pos_; - pos_++; // Starting quotation mark - - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument( - "Reached end of file without closing quoted field in " - "record"); - } else if (!s.ok()) { - return s; // Surface all other errors to caller - } + if (dataset()->use_quote_delim_ && input[current_idx] == '"') { + quoted = true; + current_idx++; } - char ch = buffer_[pos_]; - if (ch == '"') { - // When we encounter a quote, we look ahead to the next character to - // decide what to do - pos_++; - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - // This was the last field. We are done - *end_of_record = true; - return QuotedFieldToOutput(ctx, StringPiece(), out_tensors, - earlier_pieces, include); - } else if (!s.ok()) { - return s; + // Parse the body of the field + string field; + if (!quoted) { + while (current_idx < input.size() && + input[current_idx] != dataset()->delim_) { + if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || + input[current_idx] == '\n' || input[current_idx] == '\r') { + return errors::InvalidArgument( + "Unquoted fields cannot have quotes/CRLFs inside"); } - } - - char next = buffer_[pos_]; - pos_++; - if (next == dataset()->delim_) { - return QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); - - } else if (next == '\n' || next == '\r') { - *end_of_record = true; - Status s = QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include); - if (next == '\r') SkipNewLineIfNecessary(); - return s; - } else if (next != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another quote"); - } + if (include) field += input[current_idx]; + current_idx++; + } // Exit condition: end of input, or current index at delim + // Go to next field or the end + current_idx++; } else { - pos_++; - } - } - } - - // Converts quoted field to an output tensor, removing the starting - // and ending quotes from it and unescaping double quotations if - // necessary. - Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector<Tensor>* out_tensors, - const std::vector<Piece>& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - if (field.find('\"', 1) == field.size() - 1) { - // `field` contains no escaped quotation marks. - // Exclude framing quotation marks - field.remove_prefix(1); - field.remove_suffix(1); - return FieldToOutput(ctx, field, out_tensors); - } - } - string field_complete; - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - field_complete.reserve(str_len); - - // This bool flips every time we see a quote, so that we skip the second - // quote of every pair of adjacent quotes in the field. We need to track - // this across iterations of the for loop because adjacent double quotes - // may be in different buffers. Initialize to true because we also skip - // the opening quotation mark of the quoted field. - bool skip_next_quote = true; - for (const Piece& p : earlier_pieces) { - AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), - &field_complete, &skip_next_quote); - } - AppendUnescapedPiece(field, &field_complete, &skip_next_quote); - StringPiece result = StringPiece(field_complete); - result.remove_suffix(1); // Skip final quote - - return FieldToOutput(ctx, result, out_tensors); - } - - void AppendUnescapedPiece(StringPiece piece, string* field_complete, - bool* skip_next_quote) { - size_t from = 0; - size_t found = piece.find('\"', from); - while (found != string::npos) { - if (!*skip_next_quote) { - // This is the first quote in a pair of adjacent double quotes - field_complete->append(piece.data() + from, found + 1 - from); - } - *skip_next_quote = !*skip_next_quote; - from = found + 1; - found = piece.find('\"', from); - } - // Include the chunk after the last quotation mark in the string - if (from < piece.size()) { - field_complete->append(piece.data() + from, piece.size() - from); - } - } - - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseUnquotedField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector<Piece> earlier_pieces; - size_t start = pos_; - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - // Handle errors - if (errors::IsOutOfRange(s)) { - // Whatever we have is the last field of the last record - *end_of_record = true; - return UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - } else if (!s.ok()) { - return s; // Surface all other errors to caller + // Quoted field needs to be ended with '"' and delim or end + while (true) { + if (current_idx >= input.size() - 1 || input.empty()) { + if (current_idx == input.size() - 1 && + input[current_idx] == '"') { + // We're at the end of the input, and the quote terminates the + // record. Go to end. + current_idx++; + break; + } + // If there's no terminating quote, it means our buffered record + // line reader split a record up. This can happen if there is a + // newline encased in quotes. The next line is also part of the + // record, so we read it and reset the index. + if (include && current_idx == input.size() - 1) { + // TODO(rachelim): Instead of building up a string, keep track + // of terminal indices (or starting char* and length) + // Also look into using /lib/strings/Scanner + field += input[current_idx]; + } + if (include) { + field += '\n'; + } + current_idx = 0; + Status s = buffered_input_stream_->ReadLine(&input); + if (!s.ok()) { + return errors::InvalidArgument( + "Quoted field has to end with quote followed by delim, " + "CRLF, or EOF"); + } + } else if (input[current_idx] == '"' && + input[current_idx + 1] == dataset()->delim_) { + // End of field, go to next field or end + current_idx += 2; + break; + } else if (input[current_idx] == '"') { + // Current char is a quote. Since we're not at end of field, + // the next character must also be a quote. + if (input[current_idx + 1] != '"') { + return errors::InvalidArgument( + "Quote inside a string has to be escaped by another " + "quote"); + } + if (include) field += '"'; + current_idx += 2; + } else { + if (include) field += input[current_idx]; + current_idx++; + } } } - char ch = buffer_[pos_]; + num_fields_parsed++; - if (ch == dataset()->delim_) { - Status s = UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - pos_++; - return s; - } - if (ch == '\n' || ch == '\r') { - // need special case to skip over first \n of record if the line - // breaks are \r\n - Status s = UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include); - *end_of_record = true; - pos_++; - if (ch == '\r') SkipNewLineIfNecessary(); - return s; - } - if (dataset()->use_quote_delim_ && ch == '"') { - // Advance pos_ to the next field anyway so that we can ignore - // errors gracefully if required. The caller of this will be able to - // call ParseOneField and continue with the rest of the record. - AdvanceToNextField(end_of_record); - return errors::InvalidArgument( - "Unquoted fields cannot have quotes inside"); + if (include) { + // Add the tensor to the result + TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), + selector_idx, out_tensors)); + selector_idx++; + // Terminate early if we have all the fields we want + if (selector_idx == dataset()->select_cols_.size()) + return Status::OK(); } - // Otherwise, go to next character - pos_++; + } // Exit condition: current_idx has reached the end of record + + // Check if the last field is empty, and include it if necessary + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); + if (include && !input.empty() && + input[input.size() - 1] == dataset()->delim_) { + TF_RETURN_IF_ERROR( + FieldToOutput(ctx, string(), selector_idx, out_tensors)); } - } - - // Advances pos_ to the start of the next field, as delimited by delim, - // CRLF, or EOF, ignoring errors, and not keeping track of characters in - // the current field. - void AdvanceToNextField(bool* end_of_record) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - while (true) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - if (!s.ok()) { - *end_of_record = true; - return; - } - } - char ch = buffer_[pos_]; - pos_++; - - if (ch == dataset()->delim_) { - return; - } - - if (ch == '\n' || ch == '\r') { - *end_of_record = true; - if (ch == '\r') SkipNewLineIfNecessary(); - return; - } - } - } - - Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - result->clear(); - Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); - - if (errors::IsOutOfRange(s) && !result->empty()) { - // Ignore OutOfRange error when ReadNBytes read < N bytes. - return Status::OK(); + // Check that number of fields matches + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have ", + out_tensors->size(), " in record"); } - return s; + return Status::OK(); } - // Given a field, converts it to the right output tensor type - Status FieldToOutput(IteratorContext* ctx, StringPiece field, + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status FieldToOutput(IteratorContext* ctx, string field, + size_t output_idx, std::vector<Tensor>* out_tensors) { - size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -602,7 +397,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat<float>()(0); } else { float value; - if (!strings::safe_strtof(field, &value)) { + if (!strings::safe_strtof(field.c_str(), &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -617,7 +412,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat<double>()(0); } else { double value; - if (!strings::safe_strtod(field, &value)) { + if (!strings::safe_strtod(field.c_str(), &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -631,7 +426,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar<string>()() = dataset()->record_defaults_[output_idx].flat<string>()(0); } else { - component.scalar<string>()() = field.ToString(); + component.scalar<string>()() = std::move(field); } break; } @@ -644,50 +439,6 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } - // Records can be delimited by "\r\n" line breaks. When we encounter a - // '\r', we have to check the next character to see if it is part of the - // linebreak, and ignore it if so. - void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - // If we failed to fill buffer, it doesn't matter because we're done - // with the record - if (!s.ok()) return; - } - if (buffer_[pos_] == '\n') { - pos_++; - } - } - - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector<Tensor>* out_tensors, - const std::vector<Piece>& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - return FieldToOutput(ctx, field, out_tensors); - } - - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - string field_complete; - field_complete.reserve(str_len); - - for (const Piece& p : earlier_pieces) { - field_complete.append(p.buffer, p.start, p.len); - } - - field_complete.append(field.data(), field.size()); - return FieldToOutput(ctx, field_complete, out_tensors); - } - // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -701,18 +452,16 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - buffer_.clear(); - pos_ = 0; + // TODO(rachelim): Maintain our own buffer so we don't read every record + // twice + buffered_input_stream_.reset(new io::BufferedInputStream( + input_stream_.get(), dataset()->buffer_size_, false)); if (dataset()->header_) { - // Read one line, but don't include it. Pass nullptrs as dummy - // pointers to objects that shouldn't be invoked anyway - // We need to process this as a record here instead of just finding - // the first newline because it might contain quoted fields with - // newlines in the header as well - std::vector<int64> empty; - Status s = ReadRecord(nullptr, nullptr, false, empty); - if (!s.ok()) { - return errors::InvalidArgument("Can't read header of file"); + // Ignore header line + string str; + Status s = buffered_input_stream_->ReadLine(&str); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument("Can't read header of empty file"); } } return Status::OK(); @@ -721,15 +470,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); + buffered_input_stream_.reset(); file_.reset(); } mutex mu_; - string buffer_ GUARDED_BY(mu_); // Maintain our own buffer - size_t pos_ GUARDED_BY( - mu_); // Index into the buffer must be maintained between iters std::unique_ptr<io::RandomAccessInputStream> input_stream_ GUARDED_BY(mu_); + std::unique_ptr<io::BufferedInputStream> buffered_input_stream_ + GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr<RandomAccessFile> file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -742,6 +491,7 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; const std::vector<Tensor> record_defaults_; const std::vector<int64> select_cols_; + const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 523d1f2f71..c483a43769 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -128,7 +128,6 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", - "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/contrib/data/python/ops:readers", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 74b90ec7d1..8c138c7081 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -25,7 +25,6 @@ import time import numpy as np -from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -62,12 +61,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs, linebreak='\n'): + def setup_files(self, inputs): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) - with open(fn, 'wb') as f: - f.write(linebreak.join(ip).encode('utf-8')) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) + with open(fn, 'w') as f: + f.write('\n'.join(ip)) filenames.append(fn) return filenames @@ -87,47 +86,38 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) - def _verify_output_or_err(self, - sess, - dataset, - expected_output=None, - expected_err_re=None): - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, - linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs, linebreak) + filenames = self.setup_files(inputs) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - self._verify_output_or_err(sess, dataset, expected_output, - expected_err_re) - - def testCsvDataset_requiredFields(self): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def testCsvDataset_floatRequired(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -147,36 +137,10 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withEmptyFields(self): - record_defaults = [[0]] * 4 - inputs = [[',,,', '1,1,1,', ',2,2,2']] - self._test_dataset( - inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], - record_defaults=record_defaults) - - def testCsvDataset_errWithUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4']] - self._test_dataset( - inputs, - expected_err_re='Unquoted fields cannot have quotes inside', - record_defaults=record_defaults) - - def testCsvDataset_ignoreErrWithUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4', 'a,b,c"d', 'e,f,g']] - filenames = self.setup_files(inputs) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) - dataset = dataset.apply(error_ops.ignore_errors()) - self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) - - def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): - record_defaults = [['']] * 3 - inputs = [['1,2"3,4']] - self._test_by_comparison( - inputs, record_defaults=record_defaults, use_quote_delim=False) + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -200,6 +164,11 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') + def testCsvDataset_withEmptyValues(self): + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', ',6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -207,8 +176,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [['']] * 2 - inputs = [['1,2,3,4', '"5","6","7","8"']] + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -221,17 +190,27 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) - def testCsvDataset_withOneCol(self): - record_defaults = [['NA']] - inputs = [['0', '', '2']] - self._test_dataset( - inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) - def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -287,10 +266,9 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] - expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re=expected_err_re, + expected_err_re="Can't read header of empty file", record_defaults=record_defaults, header=True, ) @@ -306,7 +284,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 1 in record', + expected_err_re='Expect 2 fields but have 0 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -323,7 +301,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, [0.0]] + record_defaults = [dtypes.float32, dtypes.float32] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -348,162 +326,6 @@ class CsvDatasetOpTest(test.TestCase): self.assertEqual(result, sorted(result)) -## The following tests exercise parsing logic for quoted fields - - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - - def testCsvDataset_withOneColAndQuotes(self): - record_defaults = [['']] - inputs = [['"0"', '"1"', '"2"']] - self._test_dataset( - inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) - - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withNewLineInUnselectedCol(self): - record_defaults = [['']] - inputs = [['1,"2\n3",4', '5,6,7']] - self._test_dataset( - inputs, - expected_output=[['1'], ['5']], - record_defaults=record_defaults, - select_cols=[0]) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_errorWithTerminateMidRecord(self): - record_defaults = [['']] * 4 - inputs = [['a,b,c,"a']] - self._test_dataset( - inputs, - expected_err_re= - 'Reached end of file without closing quoted field in record', - record_defaults=record_defaults) - - def testCsvDataset_withEscapedQuotes(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - - -## Testing that parsing works with all buffer sizes, quoted/unquoted fields, -## and different types of line breaks - - def testCsvDataset_withInvalidBufferSize(self): - record_defaults = [['']] * 4 - inputs = [['a,b,c,d']] - self._test_dataset( - inputs, - expected_err_re='buffer_size should be positive', - record_defaults=record_defaults, - buffer_size=0) - - def testCsvDataset_withBufferSize(self): - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) - - def testCsvDataset_withCR(self): - # Test that when the line separator is '\r', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - - def testCsvDataset_withCRLF(self): - # Test that when the line separator is '\r\n', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['abc,def,ghi', '0,1,2', ',,']] - expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - - def testCsvDataset_withBufferSizeAndQuoted(self): - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\n', record_defaults=record_defaults) - - def testCsvDataset_withCRAndQuoted(self): - # Test that when the line separator is '\r', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\r', record_defaults=record_defaults) - - def testCsvDataset_withCRLFAndQuoted(self): - # Test that when the line separator is '\r\n', parsing works with all buffer - # sizes - record_defaults = [['NA']] * 3 - inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] - expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], - ['NA', 'NA', 'NA']] - for i in range(20): - # Test a range of buffer sizes that should all work - self._test_dataset( - inputs, - expected, - linebreak='\r\n', - record_defaults=record_defaults, - buffer_size=i + 1) - self._test_dataset( - inputs, expected, linebreak='\r\n', record_defaults=record_defaults) - class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. @@ -521,7 +343,7 @@ class CsvDatasetBenchmark(test.Benchmark): self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'wb') as f: + with open(fn, 'w') as f: # Just write 100 rows and use `repeat`... Assumes the cost # of creating an iterator is not significant row = ','.join([str_val for _ in range(n)]) diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc index f18c6dc709..987e4fe733 100644 --- a/tensorflow/core/lib/strings/numbers.cc +++ b/tensorflow/core/lib/strings/numbers.cc @@ -345,19 +345,6 @@ bool safe_strtof(const char* str, float* value) { return processed_characters_count > 0; } -bool safe_strtof(StringPiece str, float* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits<int>::max()) return false; - - *value = StringToFloatConverter().StringToFloat( - str.data(), static_cast<int>(len), &processed_characters_count); - return processed_characters_count > 0; -} - bool safe_strtod(const char* str, double* value) { int processed_characters_count = -1; auto len = str_util::Strnlen(str, kFastToBufferSize); @@ -372,19 +359,6 @@ bool safe_strtod(const char* str, double* value) { return processed_characters_count > 0; } -bool safe_strtod(StringPiece str, double* value) { - int processed_characters_count = -1; - auto len = str.size(); - - // If string length exceeds buffer size or int max, fail. - if (len >= kFastToBufferSize) return false; - if (len > std::numeric_limits<int>::max()) return false; - - *value = StringToFloatConverter().StringToDouble( - str.data(), static_cast<int>(len), &processed_characters_count); - return processed_characters_count > 0; -} - size_t FloatToBuffer(float value, char* buffer) { // FLT_DIG is 6 for IEEE-754 floats, which are used on almost all // platforms these days. Just in case some system exists where FLT_DIG diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h index f62584dedb..9cb56415cb 100644 --- a/tensorflow/core/lib/strings/numbers.h +++ b/tensorflow/core/lib/strings/numbers.h @@ -116,14 +116,12 @@ bool safe_strtou64(StringPiece str, uint64* value); // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtof(const char* str, float* value); -bool safe_strtof(StringPiece str, float* value); // Convert strings to double precision floating point values. // Leading and trailing spaces are allowed. // Values may be rounded on over- and underflow. // Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`. bool safe_strtod(const char* str, double* value); -bool safe_strtod(StringPiece str, double* value); inline bool ProtoParseNumeric(StringPiece s, int32* value) { return safe_strto32(s, value); |