Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 143
1 file changed, 81 insertions(+), 62 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index df115175f5..2a0e64caeb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -18,10 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gzip
import os
import string
import tempfile
import time
+import zlib
import numpy as np
@@ -62,18 +64,29 @@ class CsvDatasetOpTest(test.TestCase):
op2 = sess.run(next2)
self.assertAllEqual(op1, op2)
- def setup_files(self, inputs, linebreak='\n'):
+ def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
for i, ip in enumerate(inputs):
fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
- with open(fn, 'wb') as f:
- f.write(linebreak.join(ip).encode('utf-8'))
+ contents = linebreak.join(ip).encode('utf-8')
+ if compression_type is None:
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'GZIP':
+ with gzip.GzipFile(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'ZLIB':
+ contents = zlib.compress(contents)
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ else:
+ raise ValueError('Unsupported compression_type', compression_type)
filenames.append(fn)
return filenames
def _make_test_datasets(self, inputs, **kwargs):
# Test by comparing its output to what we could get with map->decode_csv
- filenames = self.setup_files(inputs)
+ filenames = self._setup_files(inputs)
dataset_expected = core_readers.TextLineDataset(filenames)
dataset_expected = dataset_expected.map(
lambda l: parsing_ops.decode_csv(l, **kwargs))
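For context, the two compression branches in _setup_files produce different on-disk formats: the GZIP branch writes a gzip-framed file (header plus CRC), while the ZLIB branch writes a bare zlib stream. A minimal sketch, using only the Python standard library and a hypothetical helper name, of how each file would be decoded when read back:

import gzip
import zlib

def decode_test_file(path, compression_type=None):
  # Undo the framing that _setup_files applied when writing the file.
  if compression_type == 'GZIP':
    with gzip.GzipFile(path, 'rb') as f:  # gzip container: header + deflate + CRC
      return f.read()
  with open(path, 'rb') as f:
    data = f.read()
  if compression_type == 'ZLIB':
    return zlib.decompress(data)  # bare zlib stream, no gzip framing
  return data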
@@ -112,15 +125,18 @@ class CsvDatasetOpTest(test.TestCase):
except errors.OutOfRangeError:
break
- def _test_dataset(self,
- inputs,
- expected_output=None,
- expected_err_re=None,
- linebreak='\n',
- **kwargs):
+ def _test_dataset(
+ self,
+ inputs,
+ expected_output=None,
+ expected_err_re=None,
+ linebreak='\n',
+ compression_type=None, # Used for both setup and parsing
+ **kwargs):
"""Checks that elements produced by CsvDataset match expected output."""
# Convert str type because py3 tf strings are bytestrings
- filenames = self.setup_files(inputs, linebreak)
+ filenames = self._setup_files(inputs, linebreak, compression_type)
+ kwargs['compression_type'] = compression_type
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, **kwargs)
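With this change, a single compression_type argument drives both sides of the round trip: _setup_files uses it to choose the writer, and it is forwarded through kwargs so CsvDataset decodes the same format. A minimal sketch of the resulting call pattern inside a test method, with inputs and record_defaults as in the tests below:

filenames = self._setup_files(
    inputs, linebreak='\r\n', compression_type='GZIP')
dataset = readers.CsvDataset(
    filenames,
    record_defaults=record_defaults,
    compression_type='GZIP')  # must match how the files were written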
@@ -174,7 +190,7 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
- filenames = self.setup_files(inputs)
+ filenames = self._setup_files(inputs)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
@@ -184,7 +200,7 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
- filenames = self.setup_files(inputs)
+ filenames = self._setup_files(inputs)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
@@ -355,7 +371,7 @@ class CsvDatasetOpTest(test.TestCase):
'1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
'1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
]]
- file_path = self.setup_files(data)
+ file_path = self._setup_files(data)
with ops.Graph().as_default() as g:
ds = readers.make_csv_dataset(
@@ -432,14 +448,29 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults=record_defaults,
buffer_size=0)
- def testCsvDataset_withBufferSize(self):
+ def _test_dataset_on_buffer_sizes(self,
+ inputs,
+ expected,
+ linebreak,
+ record_defaults,
+ compression_type=None,
+ num_sizes_to_test=20):
+    # Test reading with a range of buffer sizes that should all work.
+ for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak=linebreak,
+ compression_type=compression_type,
+ record_defaults=record_defaults,
+ buffer_size=i)
+
+ def testCsvDataset_withLF(self):
record_defaults = [['NA']] * 3
inputs = [['abc,def,ghi', '0,1,2', ',,']]
expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs, expected, record_defaults=record_defaults, buffer_size=i + 1)
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
def testCsvDataset_withCR(self):
# Test that when the line separator is '\r', parsing works with all buffer
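For reference, this is how one of the rewritten tests expands under the new helper. With the default num_sizes_to_test=20, the call inside testCsvDataset_withLF is equivalent to the loop the old tests spelled out by hand, plus a final run with the reader's default buffer size:

# Sketch of what _test_dataset_on_buffer_sizes does for testCsvDataset_withLF.
for buffer_size in list(range(1, 21)) + [None]:
  self._test_dataset(
      inputs,
      expected,
      linebreak='\n',
      compression_type=None,
      record_defaults=record_defaults,
      buffer_size=buffer_size)  # None falls through to the default buffer size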
@@ -447,14 +478,8 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults = [['NA']] * 3
inputs = [['abc,def,ghi', '0,1,2', ',,']]
expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r',
- record_defaults=record_defaults,
- buffer_size=i + 1)
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
def testCsvDataset_withCRLF(self):
# Test that when the line separator is '\r\n', parsing works with all buffer
@@ -462,29 +487,15 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults = [['NA']] * 3
inputs = [['abc,def,ghi', '0,1,2', ',,']]
expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
def testCsvDataset_withBufferSizeAndQuoted(self):
record_defaults = [['NA']] * 3
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
+ self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\n', record_defaults=record_defaults)
def testCsvDataset_withCRAndQuoted(self):
@@ -494,15 +505,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
+ self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r', record_defaults=record_defaults)
def testCsvDataset_withCRLFAndQuoted(self):
@@ -512,17 +515,33 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
['NA', 'NA', 'NA']]
- for i in range(20):
- # Test a range of buffer sizes that should all work
- self._test_dataset(
- inputs,
- expected,
- linebreak='\r\n',
- record_defaults=record_defaults,
- buffer_size=i + 1)
- self._test_dataset(
+ self._test_dataset_on_buffer_sizes(
inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+ def testCsvDataset_withGzipCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='GZIP',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withZlibCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='ZLIB',
+ record_defaults=record_defaults)
+
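Outside the test harness, the same pattern applies end to end. A standalone sketch (file path hypothetical, import path assumed to match what this test file uses) of writing a zlib-compressed CSV and reading it back with CsvDataset and the matching compression_type:

import zlib
from tensorflow.contrib.data.python.ops import readers

raw = b'1,2,3\n4,5,6'
with open('/tmp/example.csv.z', 'wb') as f:
  f.write(zlib.compress(raw))  # bare zlib stream, as in the ZLIB test above
dataset = readers.CsvDataset(
    ['/tmp/example.csv.z'],
    record_defaults=[['']] * 3,
    compression_type='ZLIB')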
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.