| author    | Rachel Lim <rachelim@google.com>                 | 2018-09-13 14:18:16 -0700 |
|-----------|--------------------------------------------------|---------------------------|
| committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2018-09-13 14:26:13 -0700 |
| commit    | d3458112ad5a1612ec6c77f7de4a0e0ec801e882 (patch)  |                           |
| tree      | d1b265fed67f2f4a9b502ea6eb32fe7c26eab2ee /tensorflow/contrib/data | |
| parent    | 885cd2942ae7b6239146a3f51ec3d6948ac2b89e (diff)   |                           |
Consistency in `record_defaults` shapes for tf.contrib.data.CsvDataset & tf.decode_csv:
- Modify shape assertions so that both graph and eager modes accept rank-0 (scalar) and rank-1 tensors as `record_defaults`, and raise an error on other shapes.
- Make the tests run in both graph and eager modes.
Fixes #22030.
PiperOrigin-RevId: 212877058
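To make the accepted shapes concrete, here is a minimal, hedged sketch (assuming a TensorFlow 1.12-era build where `tf.contrib.data.CsvDataset` and `tf.decode_csv` are available; the file name `data.csv` is a hypothetical placeholder, not part of this change). After this change, rank-0 (scalar) and rank-1 tensors are both accepted as `record_defaults`, in graph and eager mode alike:

```python
# Hedged sketch, not part of the commit: assumes a TF 1.12-era install where
# tf.contrib.data.CsvDataset exists; 'data.csv' is a hypothetical file name.
import tensorflow as tf

# Rank-0 (scalar) defaults: accepted after this change.
scalar_defaults = [tf.constant(0, dtype=tf.int64)] * 4
# Rank-1 (length-1 vector) defaults: accepted as before.
vector_defaults = [tf.constant([0], dtype=tf.int64)] * 4

for defaults in (scalar_defaults, vector_defaults):
  # CsvDataset checks record_defaults shapes when the op is constructed.
  dataset = tf.contrib.data.CsvDataset(['data.csv'], record_defaults=defaults)
  # tf.decode_csv applies the same rank-0/rank-1 rule to its record_defaults.
  fields = tf.decode_csv('1,,3,4', record_defaults=defaults)
```

Both forms should behave identically; the change only relaxes the shape check so that scalar defaults are no longer rejected.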
Diffstat (limited to 'tensorflow/contrib/data')
4 files changed, 85 insertions, 52 deletions
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 74107d5242..21ec50fb6b 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -49,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel {
     OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults",
                                         &record_defaults_list));
     for (int i = 0; i < record_defaults_list.size(); ++i) {
+      OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+                  errors::InvalidArgument(
+                      "Each record default should be at most rank 1"));
       OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
                   errors::InvalidArgument(
                       "There should only be 1 default per field but field ", i,
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ae104d55bd..ad410e17fe 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset")
       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
       // `record_defaults` must be lists of scalars
       for (size_t i = 8; i < c->num_inputs(); ++i) {
-        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+        shape_inference::ShapeHandle v;
+        TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+        if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+          return errors::InvalidArgument(
+              "Shape of a default must be a length-0 or length-1 vector, or a "
+              "scalar.");
+        }
       }
       return shape_inference::ScalarShape(c);
     });
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index b3c90ded39..ba202839b2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -72,12 +72,13 @@ py_test(
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
-        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:session",
         "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 63bffd023f..f8e74e4583 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
 
 
+@test_util.run_all_in_graph_and_eager_modes
 class CsvDatasetOpTest(test.TestCase):
 
-  def _assert_datasets_equal(self, g, ds1, ds2):
+  def _get_next(self, dataset):
+    # Returns a no argument function whose result is fed to self.evaluate to
+    # yield the next element
+    it = dataset.make_one_shot_iterator()
+    if context.executing_eagerly():
+      return it.get_next
+    else:
+      get_next = it.get_next()
+      return lambda: get_next
+
+  def _assert_datasets_equal(self, ds1, ds2):
     assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
                                                     '%s') % (ds1.output_shapes,
                                                              ds2.output_shapes)
     assert ds1.output_types == ds2.output_types
     assert ds1.output_classes == ds2.output_classes
-    next1 = ds1.make_one_shot_iterator().get_next()
-    next2 = ds2.make_one_shot_iterator().get_next()
-    with self.session(graph=g) as sess:
-      # Run through datasets and check that outputs match, or errors match.
-      while True:
-        try:
-          op1 = sess.run(next1)
-        except (errors.OutOfRangeError, ValueError) as e:
-          # If op1 throws an exception, check that op2 throws same exception.
-          with self.assertRaises(type(e)):
-            sess.run(next2)
-          break
-        op2 = sess.run(next2)
-        self.assertAllEqual(op1, op2)
+    next1 = self._get_next(ds1)
+    next2 = self._get_next(ds2)
+    # Run through datasets and check that outputs match, or errors match.
+    while True:
+      try:
+        op1 = self.evaluate(next1())
+      except (errors.OutOfRangeError, ValueError) as e:
+        # If op1 throws an exception, check that op2 throws same exception.
+        with self.assertRaises(type(e)):
+          self.evaluate(next2())
+        break
+      op2 = self.evaluate(next2())
+      self.assertAllEqual(op1, op2)
 
   def _setup_files(self, inputs, linebreak='\n', compression_type=None):
     filenames = []
@@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase):
 
   def _test_by_comparison(self, inputs, **kwargs):
     """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
-    with ops.Graph().as_default() as g:
-      dataset_actual, dataset_expected = self._make_test_datasets(
-          inputs, **kwargs)
-      self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+    dataset_actual, dataset_expected = self._make_test_datasets(
+        inputs, **kwargs)
+    self._assert_datasets_equal(dataset_actual, dataset_expected)
 
   def _verify_output_or_err(self,
-                            sess,
                             dataset,
                             expected_output=None,
                             expected_err_re=None):
-    nxt = dataset.make_one_shot_iterator().get_next()
     if expected_err_re is None:
       # Verify that output is expected, without errors
+      nxt = self._get_next(dataset)
       expected_output = [[
           v.encode('utf-8') if isinstance(v, str) else v for v in op
       ] for op in expected_output]
       for value in expected_output:
-        op = sess.run(nxt)
+        op = self.evaluate(nxt())
         self.assertAllEqual(op, value)
       with self.assertRaises(errors.OutOfRangeError):
-        sess.run(nxt)
+        self.evaluate(nxt())
     else:
       # Verify that OpError is produced as expected
       with self.assertRaisesOpError(expected_err_re):
+        nxt = self._get_next(dataset)
         while True:
           try:
-            sess.run(nxt)
+            self.evaluate(nxt())
           except errors.OutOfRangeError:
             break
 
@@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase):
     # Convert str type because py3 tf strings are bytestrings
     filenames = self._setup_files(inputs, linebreak, compression_type)
     kwargs['compression_type'] = compression_type
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, **kwargs)
-        self._verify_output_or_err(sess, dataset, expected_output,
-                                   expected_err_re)
+    dataset = readers.CsvDataset(filenames, **kwargs)
+    self._verify_output_or_err(dataset, expected_output, expected_err_re)
 
   def testCsvDataset_requiredFields(self):
     record_defaults = [[]] * 4
@@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase):
     record_defaults = [['']] * 3
     inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
     filenames = self._setup_files(inputs)
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
-        dataset = dataset.apply(error_ops.ignore_errors())
-        self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+    dataset = dataset.apply(error_ops.ignore_errors())
+    self._verify_output_or_err(dataset, [['e', 'f', 'g']])
 
   def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
     record_defaults = [['']] * 3
     inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
     filenames = self._setup_files(inputs)
-    with ops.Graph().as_default() as g:
-      with self.session(graph=g) as sess:
-        dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
-        dataset = dataset.apply(error_ops.ignore_errors())
-        self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+    dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+    dataset = dataset.apply(error_ops.ignore_errors())
+    self._verify_output_or_err(dataset, [['e', 'f', 'g']])
 
   def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
     record_defaults = [['']] * 3
@@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase):
     inputs = [['1,,3,4', '5,6,,8']]
     ds_actual, ds_expected = self._make_test_datasets(
        inputs, record_defaults=record_defaults)
-    with ops.Graph().as_default() as g:
-      self._assert_datasets_equal(g,
-                                  ds_actual.repeat(5).prefetch(1),
-                                  ds_expected.repeat(5).prefetch(1))
+    self._assert_datasets_equal(
+        ds_actual.repeat(5).prefetch(1),
+        ds_expected.repeat(5).prefetch(1))
 
   def testCsvDataset_withTypeDefaults(self):
     # Testing using dtypes as record_defaults for required fields
@@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase):
     ]]
     file_path = self._setup_files(data)
 
-    with ops.Graph().as_default() as g:
-      ds = readers.make_csv_dataset(
-          file_path, batch_size=1, shuffle=False, num_epochs=1)
-      next_batch = ds.make_one_shot_iterator().get_next()
+    ds = readers.make_csv_dataset(
+        file_path, batch_size=1, shuffle=False, num_epochs=1)
+    nxt = self._get_next(ds)
 
-      with self.session(graph=g) as sess:
-        result = list(sess.run(next_batch).values())
+    result = list(self.evaluate(nxt()).values())
 
     self.assertEqual(result, sorted(result))
 
@@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase):
         compression_type='ZLIB',
         record_defaults=record_defaults)
 
+  def testCsvDataset_withScalarDefaults(self):
+    record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+    inputs = [[',,,', '1,1,1,', ',2,2,2']]
+    self._test_dataset(
+        inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+        record_defaults=record_defaults)
+
+  def testCsvDataset_with2DDefaults(self):
+    record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+    inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+    if context.executing_eagerly():
+      err_spec = errors.InvalidArgumentError, (
+          'Each record default should be at '
+          'most rank 1.')
+    else:
+      err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+    with self.assertRaisesWithPredicateMatch(*err_spec):
+      self._test_dataset(
+          inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+          record_defaults=record_defaults)
+
 
 class CsvDatasetBenchmark(test.Benchmark):
   """Benchmarks for the various ways of creating a dataset from CSV files.
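For the opposite case, here is a minimal sketch of the failure mode that the new `testCsvDataset_with2DDefaults` test above exercises (same assumptions as the earlier sketch; the file name is again a hypothetical placeholder). A rank-2 default should be rejected, surfacing as a shape-inference `ValueError` in graph mode or an `InvalidArgumentError` from the kernel in eager mode:

```python
# Hedged sketch, not part of the commit: assumes tf.contrib.data.CsvDataset
# from a TF 1.12-era build; 'data.csv' is a hypothetical file name.
import tensorflow as tf

bad_defaults = [tf.constant([[0]], dtype=tf.int64)] * 4  # rank 2: invalid

try:
  tf.contrib.data.CsvDataset(['data.csv'], record_defaults=bad_defaults)
except (ValueError, tf.errors.InvalidArgumentError) as err:
  # Graph mode: ValueError from shape inference
  #   ("Shape must be at most rank 1 but is rank 2").
  # Eager mode: InvalidArgumentError from the kernel
  #   ("Each record default should be at most rank 1").
  print('rejected as expected:', err)
```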