author    Rachel Lim <rachelim@google.com>                  2018-09-13 14:18:16 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-13 14:26:13 -0700
commit    d3458112ad5a1612ec6c77f7de4a0e0ec801e882 (patch)
tree      d1b265fed67f2f4a9b502ea6eb32fe7c26eab2ee /tensorflow/contrib/data
parent    885cd2942ae7b6239146a3f51ec3d6948ac2b89e (diff)
Consistency in record_default shapes for tf.contrib.data.CsvDataset & tf.decode_csv:
- Modify shape assertions so that both graph and eager accept rank 0 (scalar) and rank 1 tensors as `record_defaults`, and raise an error on other shapes.
- Make tests run in both graph and eager modes.

Fixes #22030.

PiperOrigin-RevId: 212877058
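For illustration, a minimal sketch of what this change permits, using the TF 1.x contrib API of this era (the `data.csv` filename and the four-int64-column layout are hypothetical):

    import tensorflow as tf

    # Rank-0 (scalar) defaults -- accepted after this change:
    scalar_defaults = [tf.constant(0, dtype=tf.int64)] * 4
    ds = tf.contrib.data.CsvDataset(["data.csv"], record_defaults=scalar_defaults)

    # Rank-1 (length-1 vector) defaults -- accepted as before:
    vector_defaults = [tf.constant([0], dtype=tf.int64)] * 4
    ds = tf.contrib.data.CsvDataset(["data.csv"], record_defaults=vector_defaults)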
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc                     3
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc                            8
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD                     3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py  123
4 files changed, 85 insertions, 52 deletions
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 74107d5242..21ec50fb6b 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -49,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx,
ctx->input_list("record_defaults", &record_defaults_list));
for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ae104d55bd..ad410e17fe 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset")
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
// `record_defaults` must be lists of scalars
for (size_t i = 8; i < c->num_inputs(); ++i) {
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
}
return shape_inference::ScalarShape(c);
});
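In graph mode, by contrast, the same misuse is caught at graph-construction time by the WithRankAtMost check above, so it surfaces as a ValueError from shape inference rather than an InvalidArgumentError from the kernel. A minimal sketch (again with a hypothetical `data.csv`):

    import tensorflow as tf

    # Graph mode: the rank-2 default is rejected during shape inference.
    defaults = [tf.constant([[0]], dtype=tf.int64)] * 4
    try:
        tf.contrib.data.CsvDataset(["data.csv"], record_defaults=defaults)
    except ValueError as e:
        print(e)  # "Shape must be at most rank 1 but is rank 2"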
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b3c90ded39..ba202839b2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -72,12 +72,13 @@ py_test(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 63bffd023f..f8e74e4583 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class CsvDatasetOpTest(test.TestCase):
- def _assert_datasets_equal(self, g, ds1, ds2):
+ def _get_next(self, dataset):
+ # Returns a no argument function whose result is fed to self.evaluate to
+ # yield the next element
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ get_next = it.get_next()
+ return lambda: get_next
+
+ def _assert_datasets_equal(self, ds1, ds2):
assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
'%s') % (ds1.output_shapes,
ds2.output_shapes)
assert ds1.output_types == ds2.output_types
assert ds1.output_classes == ds2.output_classes
- next1 = ds1.make_one_shot_iterator().get_next()
- next2 = ds2.make_one_shot_iterator().get_next()
- with self.session(graph=g) as sess:
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = sess.run(next1)
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- sess.run(next2)
- break
- op2 = sess.run(next2)
- self.assertAllEqual(op1, op2)
+ next1 = self._get_next(ds1)
+ next2 = self._get_next(ds2)
+ # Run through datasets and check that outputs match, or errors match.
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except (errors.OutOfRangeError, ValueError) as e:
+ # If op1 throws an exception, check that op2 throws same exception.
+ with self.assertRaises(type(e)):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+ self.assertAllEqual(op1, op2)
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase):
def _test_by_comparison(self, inputs, **kwargs):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
- with ops.Graph().as_default() as g:
- dataset_actual, dataset_expected = self._make_test_datasets(
- inputs, **kwargs)
- self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self._assert_datasets_equal(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
- sess,
dataset,
expected_output=None,
expected_err_re=None):
- nxt = dataset.make_one_shot_iterator().get_next()
if expected_err_re is None:
# Verify that output is expected, without errors
+ nxt = self._get_next(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
for value in expected_output:
- op = sess.run(nxt)
+ op = self.evaluate(nxt())
self.assertAllEqual(op, value)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
+ self.evaluate(nxt())
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
+ nxt = self._get_next(dataset)
while True:
try:
- sess.run(nxt)
+ self.evaluate(nxt())
except errors.OutOfRangeError:
break
@@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase):
# Convert str type because py3 tf strings are bytestrings
filenames = self._setup_files(inputs, linebreak, compression_type)
kwargs['compression_type'] = compression_type
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(sess, dataset, expected_output,
- expected_err_re)
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ self._verify_output_or_err(dataset, expected_output, expected_err_re)
def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4
@@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
record_defaults = [['']] * 3
@@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- with ops.Graph().as_default() as g:
- self._assert_datasets_equal(g,
- ds_actual.repeat(5).prefetch(1),
- ds_expected.repeat(5).prefetch(1))
+ self._assert_datasets_equal(
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
@@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase):
]]
file_path = self._setup_files(data)
- with ops.Graph().as_default() as g:
- ds = readers.make_csv_dataset(
- file_path, batch_size=1, shuffle=False, num_epochs=1)
- next_batch = ds.make_one_shot_iterator().get_next()
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ nxt = self._get_next(ds)
- with self.session(graph=g) as sess:
- result = list(sess.run(next_batch).values())
+ result = list(self.evaluate(nxt()).values())
self.assertEqual(result, sorted(result))
@@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase):
compression_type='ZLIB',
record_defaults=record_defaults)
+ def testCsvDataset_withScalarDefaults(self):
+ record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_with2DDefaults(self):
+ record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ 'Each record default should be at '
+ 'most rank 1.')
+ else:
+ err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.