diff options
author | Derek Murray <mrry@google.com> | 2018-08-03 16:37:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 16:41:03 -0700 |
commit | 66dd14547dd9edb4eba13d22361ddad4a1cd3353 (patch) | |
tree | b43c2dc02b7c81fa4f4871732cad93b9b554052a | |
parent | 3229c87b35bac470c207872ca446829a83d6e629 (diff) |
[tf.data] Raise an InvalidArgumentError if the argument to Dataset.list_files() matches no files.
The main effect of this change is to change a late `OutOfRangeError` at iteration time into an earlier `InvalidArgumentError` at initialization time, which will improve error reporting in Eager mode and high-level APIs that control the training loop (Estimator, Keras, etc.). This change will break some advanced uses that concatenate many potentially empty file listings, but it is possible to work around this using `tf.data.Dataset.from_tensor_slices(tf.matching_files(file_pattern))`. We expect that the improved productivity from an earlier, more actionable error message will outweigh the inconvenience of modifying a small number of existing programs.
PiperOrigin-RevId: 207344116
-rw-r--r-- | tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/data/ops/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 41 |
3 files changed, 37 insertions, 18 deletions
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index f7d7d085c9..579096f880 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -123,13 +123,11 @@ class ListFilesDatasetOpTest(test.TestCase):
     with self.test_session() as sess:
       itr = dataset.make_initializable_iterator()
-      next_element = itr.get_next()
-      sess.run(
-          itr.initializer,
-          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
-
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError, 'No files matched pattern: '):
+        sess.run(
+            itr.initializer,
+            feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})

   def testSimpleDirectoryInitializer(self):
     filenames = ['a', 'b', 'c']
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 61bf9783ab..50ba5f403e 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -11,6 +11,7 @@ py_library(
     deps = [
         ":iterator_ops",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
@@ -19,6 +20,7 @@ py_library(
         "//tensorflow/python:random_seed",
         "//tensorflow/python:script_ops",
         "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:util",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 88de4b588c..6cda2a77cc 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -39,10 +39,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
@@ -644,17 +646,34 @@ class Dataset(object):
     Returns:
       Dataset: A `Dataset` of strings corresponding to file names.
     """
-    if shuffle is None:
-      shuffle = True
-    matching_files = gen_io_ops.matching_files(file_pattern)
-    dataset = Dataset.from_tensor_slices(matching_files)
-    if shuffle:
-      # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
-      # list of files might be empty.
-      buffer_size = math_ops.maximum(
-          array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
-      dataset = dataset.shuffle(buffer_size, seed=seed)
-    return dataset
+    with ops.name_scope("list_files"):
+      if shuffle is None:
+        shuffle = True
+      file_pattern = ops.convert_to_tensor(
+          file_pattern, dtype=dtypes.string, name="file_pattern")
+      matching_files = gen_io_ops.matching_files(file_pattern)
+
+      # Raise an exception if `file_pattern` does not match any files.
+      condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
+                                   name="match_not_empty")
+
+      message = math_ops.add(
+          "No files matched pattern: ",
+          string_ops.reduce_join(file_pattern, separator=", "), name="message")
+
+      assert_not_empty = control_flow_ops.Assert(
+          condition, [message], summarize=1, name="assert_not_empty")
+      with ops.control_dependencies([assert_not_empty]):
+        matching_files = array_ops.identity(matching_files)
+
+      dataset = Dataset.from_tensor_slices(matching_files)
+      if shuffle:
+        # NOTE(mrry): The shuffle buffer size must be greater than zero, but
+        # the list of files might be empty.
+        buffer_size = math_ops.maximum(
+            array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
+        dataset = dataset.shuffle(buffer_size, seed=seed)
+      return dataset

   def repeat(self, count=None):
     """Repeats this dataset `count` times.