path: root/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
Diffstat (limited to 'tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 875
1 file changed, 501 insertions(+), 374 deletions(-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 9df403ef50..851a33dfc8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -17,13 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gzip
import os
+import zlib
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -182,264 +185,363 @@ class ReadBatchFeaturesTest(
class MakeCsvDatasetTest(test.TestCase):
- COLUMN_TYPES = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
- ]
- COLUMNS = ["col%d" % i for i in range(len(COLUMN_TYPES))]
- DEFAULT_VALS = [[], [], [], [], ["NULL"]]
- DEFAULTS = [
- constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([], dtype=dtypes.int64),
- constant_op.constant([], dtype=dtypes.float32),
- constant_op.constant([], dtype=dtypes.float64),
- constant_op.constant(["NULL"], dtype=dtypes.string)
- ]
- LABEL = COLUMNS[0]
-
- def setUp(self):
- super(MakeCsvDatasetTest, self).setUp()
- self._num_files = 2
- self._num_records = 11
- self._test_filenames = self._create_files()
-
- def _csv_values(self, fileno, recordno):
- return [
- fileno,
- recordno,
- fileno * recordno * 0.5,
- fileno * recordno + 0.5,
- "record %d" % recordno if recordno % 2 == 1 else "",
- ]
+ def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
+ return readers.make_csv_dataset(
+ filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs)
- def _write_file(self, filename, rows):
- for i in range(len(rows)):
- if isinstance(rows[i], list):
- rows[i] = ",".join(str(v) if v is not None else "" for v in rows[i])
- fn = os.path.join(self.get_temp_dir(), filename)
- f = open(fn, "w")
- f.write("\n".join(rows))
- f.close()
- return fn
-
- def _create_file(self, fileno, header=True):
- rows = []
- if header:
- rows.append(self.COLUMNS)
- for recno in range(self._num_records):
- rows.append(self._csv_values(fileno, recno))
- return self._write_file("csv_file%d.csv" % fileno, rows)
-
- def _create_files(self):
+ def _setup_files(self, inputs, linebreak="\n", compression_type=None):
filenames = []
- for i in range(self._num_files):
- filenames.append(self._create_file(i))
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i)
+ contents = linebreak.join(ip).encode("utf-8")
+ if compression_type is None:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+ filenames.append(fn)
return filenames
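
The new _setup_files helper replaces the old per-file writers and covers all three supported compression modes in one place. As a rough standalone sketch (standard library only; the helper name and temp-dir handling are illustrative, not part of the test), the same three write paths look like this:

    import gzip
    import os
    import tempfile
    import zlib

    def write_csv(lines, compression_type=None):
        # Illustrative helper mirroring _setup_files: one file, three modes.
        fn = os.path.join(tempfile.mkdtemp(), "temp.csv")
        contents = "\n".join(lines).encode("utf-8")
        if compression_type is None:
            with open(fn, "wb") as f:
                f.write(contents)
        elif compression_type == "GZIP":
            with gzip.GzipFile(fn, "wb") as f:  # GzipFile adds gzip framing/CRC
                f.write(contents)
        elif compression_type == "ZLIB":
            with open(fn, "wb") as f:  # raw zlib stream, no gzip framing
                f.write(zlib.compress(contents))
        else:
            raise ValueError("Unsupported compression_type", compression_type)
        return fn
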
- def _make_csv_dataset(
- self,
- filenames,
- defaults,
- column_names=COLUMNS,
- label_name=LABEL,
- select_cols=None,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- shuffle_seed=None,
- header=True,
- na_value="",
- ):
- return readers.make_csv_dataset(
- filenames,
- batch_size=batch_size,
- column_names=column_names,
- column_defaults=defaults,
- label_name=label_name,
- num_epochs=num_epochs,
- shuffle=shuffle,
- shuffle_seed=shuffle_seed,
- header=header,
- na_value=na_value,
- select_columns=select_cols,
- )
-
- def _next_actual_batch(self, file_indices, batch_size, num_epochs, defaults):
- features = {col: list() for col in self.COLUMNS}
+ def _next_expected_batch(self, expected_output, expected_keys, batch_size,
+ num_epochs):
+ features = {k: [] for k in expected_keys}
for _ in range(num_epochs):
- for i in file_indices:
- for j in range(self._num_records):
- values = self._csv_values(i, j)
- for n, v in enumerate(values):
- if v == "": # pylint: disable=g-explicit-bool-comparison
- values[n] = defaults[n][0]
- values[-1] = values[-1].encode("utf-8")
-
- # Regroup lists by column instead of row
- for n, col in enumerate(self.COLUMNS):
- features[col].append(values[n])
- if len(list(features.values())[0]) == batch_size:
- yield features
- features = {col: list() for col in self.COLUMNS}
-
- def _run_actual_batch(self, outputs, sess):
- features, labels = sess.run(outputs)
- batch = [features[k] for k in self.COLUMNS if k != self.LABEL]
- batch.append(labels)
- return batch
-
- def _verify_records(
+ for values in expected_output:
+ for n, key in enumerate(expected_keys):
+ features[key].append(values[n])
+ if len(features[expected_keys[0]]) == batch_size:
+ yield features
+ features = {k: [] for k in expected_keys}
+ if features[expected_keys[0]]: # Leftover from the last batch
+ yield features
+
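
_next_expected_batch regroups the row-major expected_output into the column-major feature dicts that make_csv_dataset emits, yielding a short final batch when records don't divide evenly. A minimal pure-Python illustration of that regrouping (values are hypothetical):

    # Illustration only: row-major rows -> column-major batches of size 2,
    # with a short leftover batch, mirroring _next_expected_batch above.
    rows = [[0, b"a"], [1, b"b"], [2, b"c"]]
    keys = ["x", "y"]

    def batches(rows, keys, batch_size):
        feats = {k: [] for k in keys}
        for row in rows:
            for k, v in zip(keys, row):
                feats[k].append(v)
            if len(feats[keys[0]]) == batch_size:
                yield feats
                feats = {k: [] for k in keys}
        if feats[keys[0]]:  # leftover partial batch
            yield feats

    print(list(batches(rows, keys, 2)))
    # [{'x': [0, 1], 'y': [b'a', b'b']}, {'x': [2], 'y': [b'c']}]
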
+ def _verify_output(
self,
sess,
dataset,
- file_indices,
- defaults=tuple(DEFAULT_VALS),
- label_name=LABEL,
- batch_size=1,
- num_epochs=1,
+ batch_size,
+ num_epochs,
+ label_name,
+ expected_output,
+ expected_keys,
):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
+ nxt = dataset.make_one_shot_iterator().get_next()
- for expected_features in self._next_actual_batch(file_indices, batch_size,
- num_epochs, defaults):
- actual_features = sess.run(get_next)
+ for expected_features in self._next_expected_batch(
+ expected_output,
+ expected_keys,
+ batch_size,
+ num_epochs,
+ ):
+ actual_features = sess.run(nxt)
if label_name is not None:
expected_labels = expected_features.pop(label_name)
- # Compare labels
self.assertAllEqual(expected_labels, actual_features[1])
- actual_features = actual_features[0] # Extract features dict from tuple
+ actual_features = actual_features[0]
for k in expected_features.keys():
# Compare features
self.assertAllEqual(expected_features[k], actual_features[k])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMakeCSVDataset(self):
- defaults = self.DEFAULTS
-
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- # Basic test: read from file 0.
- dataset = self._make_csv_dataset(self._test_filenames[0], defaults)
- self._verify_records(sess, dataset, [0])
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- # Basic test: read from file 1.
- dataset = self._make_csv_dataset(self._test_filenames[1], defaults)
- self._verify_records(sess, dataset, [1])
+ sess.run(nxt)
+
+ def _test_dataset(self,
+ inputs,
+ expected_output,
+ expected_keys,
+ batch_size=1,
+ num_epochs=1,
+ label_name=None,
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Convert str type because py3 tf strings are bytestrings
+ filenames = self._setup_files(
+ inputs, compression_type=kwargs.get("compression_type", None))
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- # Read from both files.
- dataset = self._make_csv_dataset(self._test_filenames, defaults)
- self._verify_records(sess, dataset, range(self._num_files))
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- # Read from both files. Exercise the `batch` and `num_epochs` parameters
- # of make_csv_dataset and make sure they work.
dataset = self._make_csv_dataset(
- self._test_filenames, defaults, batch_size=2, num_epochs=10)
- self._verify_records(
- sess, dataset, range(self._num_files), batch_size=2, num_epochs=10)
+ filenames,
+ batch_size=batch_size,
+ num_epochs=num_epochs,
+ label_name=label_name,
+ **kwargs)
+ self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
+ expected_output, expected_keys)
+
+ def testMakeCSVDataset(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
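
For context, this is how the dataset built by the test would be consumed directly (TF 1.x graph mode, matching the rest of this file; "data.csv" and the surrounding setup are illustrative, not part of the diff):

    # Illustrative only: consuming make_csv_dataset output with a label_name.
    dataset = readers.make_csv_dataset(
        "data.csv",  # hypothetical file with header col0,...,col4
        batch_size=1,
        column_names=column_names,
        column_defaults=record_defaults,
        label_name="col0",
        num_epochs=1,
        shuffle=False,
        header=True)
    features, labels = dataset.make_one_shot_iterator().get_next()
    # `features` maps each remaining column name to a batched tensor;
    # `labels` is the tensor popped out of that dict under "col0".
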
+ def testMakeCSVDataset_withBatchSizeAndEpochs(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=3,
+ num_epochs=10,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
- def testMakeCSVDataset_withBadColumns(self):
+ def testMakeCSVDataset_withCompressionType(self):
+ """Tests `compression_type` argument."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ for compression_type in ("GZIP", "ZLIB"):
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ compression_type=compression_type,
+ )
+
+ def testMakeCSVDataset_withBadInputs(self):
"""Tests that exception is raised when input is malformed.
"""
- dupe_columns = self.COLUMNS[:-1] + self.COLUMNS[:1]
- defaults = self.DEFAULTS
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
# Duplicate column names
with self.assertRaises(ValueError):
self._make_csv_dataset(
- self._test_filenames, defaults, column_names=dupe_columns)
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="col0",
+ column_names=column_names * 2)
# Label key not one of column names
with self.assertRaises(ValueError):
self._make_csv_dataset(
- self._test_filenames, defaults, label_name="not_a_real_label")
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="not_a_real_label",
+ column_names=column_names)
def testMakeCSVDataset_withNoLabel(self):
- """Tests that CSV datasets can be created when no label is specified.
- """
- defaults = self.DEFAULTS
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- # Read from both files. Make sure this works with no label key supplied.
- dataset = self._make_csv_dataset(
- self._test_filenames,
- defaults,
- batch_size=2,
- num_epochs=10,
- label_name=None)
- self._verify_records(
- sess,
- dataset,
- range(self._num_files),
- batch_size=2,
- num_epochs=10,
- label_name=None)
+ """Tests making a CSV dataset with no label provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
def testMakeCSVDataset_withNoHeader(self):
"""Tests that datasets can be created from CSV files with no header line.
"""
- defaults = self.DEFAULTS
- file_without_header = self._create_file(
- len(self._test_filenames), header=False)
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- file_without_header,
- defaults,
- batch_size=2,
- num_epochs=10,
- header=False,
- )
- self._verify_records(
- sess,
- dataset,
- [len(self._test_filenames)],
- batch_size=2,
- num_epochs=10,
- )
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=False,
+ column_defaults=record_defaults,
+ )
def testMakeCSVDataset_withTypes(self):
"""Tests that defaults can be a dtype instead of a Tensor for required vals.
"""
- defaults = [d for d in self.COLUMN_TYPES[:-1]]
- defaults.append(constant_op.constant(["NULL"], dtype=dtypes.string))
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(self._test_filenames, defaults)
- self._verify_records(sess, dataset, range(self._num_files))
+ record_defaults = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
+ dtypes.string
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
+ [
+ ",".join(x[0] for x in column_names), "10,11,12,13,14",
+ "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
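
The point of this test is that column_defaults entries may be bare dtypes rather than empty constant tensors when a column is required. Both spellings below should be equivalent (a sketch of the two forms, not exercised by the test itself):

    # Two equivalent ways to declare a required int32 column with no default:
    required_as_tensor = constant_op.constant([], dtype=dtypes.int32)
    required_as_dtype = dtypes.int32
    # A column that does have a fallback supplies a one-element tensor:
    optional_string = constant_op.constant(["NULL"], dtype=dtypes.string)
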
def testMakeCSVDataset_withNoColNames(self):
"""Tests that datasets can be created when column names are not specified.
In that case, we should infer the column names from the header lines.
"""
- defaults = self.DEFAULTS
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- # Read from both files. Exercise the `batch` and `num_epochs` parameters
- # of make_csv_dataset and make sure they work.
- dataset = self._make_csv_dataset(
- self._test_filenames,
- defaults,
- column_names=None,
- batch_size=2,
- num_epochs=10)
- self._verify_records(
- sess, dataset, range(self._num_files), batch_size=2, num_epochs=10)
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
def testMakeCSVDataset_withTypeInferenceMismatch(self):
    # Test that an error is raised when the number of fields doesn't match
    # the number of provided column names
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
with self.assertRaises(ValueError):
self._make_csv_dataset(
- self._test_filenames,
- column_names=self.COLUMNS + ["extra_name"],
- defaults=None,
+ filenames,
+ column_names=column_names + ["extra_name"],
+ column_defaults=None,
batch_size=2,
num_epochs=10)
@@ -448,197 +550,215 @@ class MakeCsvDatasetTest(test.TestCase):
In that case, we should infer the types from the first N records.
"""
- # Test that it works with standard test files (with header, etc)
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- self._test_filenames, defaults=None, batch_size=2, num_epochs=10)
- self._verify_records(
- sess,
- dataset,
- range(self._num_files),
- batch_size=2,
- num_epochs=10,
- defaults=[[], [], [], [], [""]])
-
- def testMakeCSVDataset_withTypeInferenceTricky(self):
- # Test on a deliberately tricky file (type changes as we read more rows, and
- # there are null values)
- fn = os.path.join(self.get_temp_dir(), "file.csv")
- expected_dtypes = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32,
- dtypes.string, dtypes.string
- ]
- col_names = ["col%d" % i for i in range(len(expected_dtypes))]
- rows = [[None, None, None, "NAN", "",
- "a"], [1, 2**31 + 1, 2**64, 123, "NAN", ""],
- ['"123"', 2, 2**64, 123.4, "NAN", '"cd,efg"']]
- expected = [[0, 0, 0, 0, "", "a"], [1, 2**31 + 1, 2**64, 123, "", ""],
- [123, 2, 2**64, 123.4, "", "cd,efg"]]
- for row in expected:
- row[-1] = row[-1].encode("utf-8") # py3 expects byte strings
- row[-2] = row[-2].encode("utf-8") # py3 expects byte strings
- self._write_file("file.csv", [col_names] + rows)
+ column_names = ["col%d" % i for i in range(5)]
+    str_int32_max = str(2**33)  # 2**33 exceeds int32 max, forcing int64
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
+ label = "col0"
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=None,
- column_names=None,
- label_name=None,
- na_value="NAN",
- )
- features = dataset.make_one_shot_iterator().get_next()
- # Check that types match
- for i in range(len(expected_dtypes)):
- print(features["col%d" % i].dtype, expected_dtypes[i])
- assert features["col%d" % i].dtype == expected_dtypes[i]
- for i in range(len(rows)):
- assert sess.run(features) == dict(zip(col_names, expected[i]))
-
- def testMakeCSVDataset_withTypeInferenceAllTypes(self):
- # Test that we make the correct inference for all types with fallthrough
- fn = os.path.join(self.get_temp_dir(), "file.csv")
- expected_dtypes = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
- dtypes.string, dtypes.string
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
+ def testMakeCSVDataset_withTypeInferenceFallthrough(self):
+ """Tests that datasets can be created when no defaults are specified.
+
+ Tests on a deliberately tricky file.
+ """
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ ",,,,",
+ "0,0,0.0,0.0,0.0",
+ "0,%s,2.0,3e50,rabbit" % str_int32_max,
+ ",,,,",
+ ]]
+ expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"],
+ [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
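
The fallthrough order these two inference tests rely on is int32 -> int64 -> float32 -> float64 -> string. A plain-Python sketch of that widening rule (illustrative, not the actual TF inference code; empty fields are skipped because defaults fill them):

    INT32_MAX = 2**31 - 1

    def infer_column_type(values):
        # Widen per column: int32 -> int64 -> float32 -> float64 -> string.
        order = ["int32", "int64", "float32", "float64", "string"]
        inferred = "int32"
        for v in values:
            if v == "":
                continue  # empty fields don't widen the type
            try:
                needed = "int32" if abs(int(v)) <= INT32_MAX else "int64"
            except ValueError:
                try:
                    # 3e50 overflows float32's ~3.4e38 max, so widen to float64
                    needed = "float32" if abs(float(v)) <= 3.4e38 else "float64"
                except ValueError:
                    needed = "string"
            if order.index(needed) > order.index(inferred):
                inferred = needed
        return inferred

    print(infer_column_type(["0", str(2**33)]))  # int64
    print(infer_column_type(["2.0"]))            # float32
    print(infer_column_type(["3e50"]))           # float64
    print(infer_column_type(["", "rabbit"]))     # string
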
+ def testMakeCSVDataset_withSelectCols(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
]
- col_names = ["col%d" % i for i in range(len(expected_dtypes))]
- rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]]
- expected = [[
- 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8")
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
]]
- self._write_file("file.csv", [col_names] + rows)
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=None,
- column_names=None,
- label_name=None,
- na_value="NAN",
- )
- features = dataset.make_one_shot_iterator().get_next()
- # Check that types match
- for i in range(len(expected_dtypes)):
- self.assertAllEqual(features["col%d" % i].dtype, expected_dtypes[i])
- for i in range(len(rows)):
- self.assertAllEqual(
- sess.run(features), dict(zip(col_names, expected[i])))
+ select_cols = [1, 3, 4]
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ column_defaults=[record_defaults[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do inference without provided defaults
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do column name inference
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can specify column names instead of indices
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=[column_names[i] for i in select_cols],
+ )
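
All four variants above reduce to the same contract: with select_columns set, only the selected keys appear in the output, and column indices and column names are interchangeable. A condensed sketch (filenames as produced by _setup_files; no label_name, so get_next() yields just the feature dict):

    dataset = readers.make_csv_dataset(
        filenames,
        batch_size=1,
        header=True,
        shuffle=False,
        num_epochs=1,
        select_columns=[1, 3, 4])  # equivalently: ["col1", "col3", "col4"]
    features = dataset.make_one_shot_iterator().get_next()
    assert sorted(features.keys()) == ["col1", "col3", "col4"]
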
def testMakeCSVDataset_withSelectColsError(self):
- data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
- col_names = ["col%d" % i for i in range(5)]
- fn = self._write_file("file.csv", [col_names] + data)
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+
+ select_cols = [1, 3, 4]
+ filenames = self._setup_files(inputs)
+
with self.assertRaises(ValueError):
# Mismatch in number of defaults and number of columns selected,
# should raise an error
self._make_csv_dataset(
- fn,
- defaults=[[0]] * 5,
- column_names=col_names,
- label_name=None,
- select_cols=[1, 3])
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ select_columns=select_cols)
+
with self.assertRaises(ValueError):
# Invalid column name should raise an error
self._make_csv_dataset(
- fn,
- defaults=[[0]],
- column_names=col_names,
+ filenames,
+ batch_size=1,
+ column_defaults=[[0]],
+ column_names=column_names,
label_name=None,
- select_cols=["invalid_col_name"])
-
- def testMakeCSVDataset_withSelectCols(self):
- data = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
- col_names = ["col%d" % i for i in range(5)]
- fn = self._write_file("file.csv", [col_names] + data)
- # If select_cols is specified, should only yield a subset of columns
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=[[0], [0]],
- column_names=col_names,
- label_name=None,
- select_cols=[1, 3])
- expected = [[1, 3], [6, 8]]
- features = dataset.make_one_shot_iterator().get_next()
- for i in range(len(data)):
- self.assertAllEqual(
- sess.run(features),
- dict(zip([col_names[1], col_names[3]], expected[i])))
- # Can still do default inference with select_cols
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=None,
- column_names=col_names,
- label_name=None,
- select_cols=[1, 3])
- expected = [[1, 3], [6, 8]]
- features = dataset.make_one_shot_iterator().get_next()
- for i in range(len(data)):
- self.assertAllEqual(
- sess.run(features),
- dict(zip([col_names[1], col_names[3]], expected[i])))
- # Can still do column name inference
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=None,
- column_names=None,
- label_name=None,
- select_cols=[1, 3])
- expected = [[1, 3], [6, 8]]
- features = dataset.make_one_shot_iterator().get_next()
- for i in range(len(data)):
- self.assertAllEqual(
- sess.run(features),
- dict(zip([col_names[1], col_names[3]], expected[i])))
- # Can specify column names instead of indices
- with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- fn,
- defaults=None,
- column_names=None,
- label_name=None,
- select_cols=[col_names[1], col_names[3]])
- expected = [[1, 3], [6, 8]]
- features = dataset.make_one_shot_iterator().get_next()
- for i in range(len(data)):
- self.assertAllEqual(
- sess.run(features),
- dict(zip([col_names[1], col_names[3]], expected[i])))
+ select_columns=["invalid_col_name"])
def testMakeCSVDataset_withShuffle(self):
- total_records = self._num_files * self._num_records
- defaults = self.DEFAULTS
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ def str_series(st):
+ return ",".join(str(i) for i in range(st, st + 5))
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [
+ [",".join(x for x in column_names)
+ ] + [str_series(5 * i) for i in range(15)],
+ [",".join(x for x in column_names)] +
+ [str_series(5 * i) for i in range(15, 20)],
+ ]
+
+ filenames = self._setup_files(inputs)
+
+ total_records = 20
for batch_size in [1, 2]:
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Test that shuffling with the same seed produces the same result
dataset1 = self._make_csv_dataset(
- self._test_filenames,
- defaults,
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
batch_size=batch_size,
+ header=True,
shuffle=True,
- shuffle_seed=5)
+ shuffle_seed=5,
+ num_epochs=2,
+ )
dataset2 = self._make_csv_dataset(
- self._test_filenames,
- defaults,
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
batch_size=batch_size,
+ header=True,
shuffle=True,
- shuffle_seed=5)
+ shuffle_seed=5,
+ num_epochs=2,
+ )
outputs1 = dataset1.make_one_shot_iterator().get_next()
outputs2 = dataset2.make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
for i in range(len(batch1)):
self.assertAllEqual(batch1[i], batch2[i])
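
The comparison relies on nest.flatten producing tensors in a deterministic order (dict values sorted by key, then the label), so two runs can be checked element-wise. A small illustration of that flattening with hypothetical values:

    # Element structure is (features_dict, labels); nest.flatten yields the
    # dict's values in sorted key order, followed by the label.
    element = ({"col1": np.array([1]), "col2": np.array([2])}, np.array([0]))
    flat = nest.flatten(element)
    # flat == [array([1]), array([2]), array([0])]
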
@@ -646,23 +766,31 @@ class MakeCsvDatasetTest(test.TestCase):
with self.test_session(graph=g) as sess:
# Test that shuffling with a different seed produces different results
dataset1 = self._make_csv_dataset(
- self._test_filenames,
- defaults,
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
batch_size=batch_size,
+ header=True,
shuffle=True,
- shuffle_seed=5)
+ shuffle_seed=5,
+ num_epochs=2,
+ )
dataset2 = self._make_csv_dataset(
- self._test_filenames,
- defaults,
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
batch_size=batch_size,
+ header=True,
shuffle=True,
- shuffle_seed=6)
+ shuffle_seed=6,
+ num_epochs=2,
+ )
outputs1 = dataset1.make_one_shot_iterator().get_next()
outputs2 = dataset2.make_one_shot_iterator().get_next()
all_equal = False
for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
for i in range(len(batch1)):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
@@ -874,6 +1002,5 @@ class MakeTFRecordDatasetTest(
self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
seed=21345)
-
if __name__ == "__main__":
test.main()