aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-08-03 16:37:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 16:41:03 -0700
commit66dd14547dd9edb4eba13d22361ddad4a1cd3353 (patch)
treeb43c2dc02b7c81fa4f4871732cad93b9b554052a
parent3229c87b35bac470c207872ca446829a83d6e629 (diff)
[tf.data] Raise an InvalidArgumentError if the argument to Dataset.list_files() matches no files.
The main effect of this change is to change a late `OutOfRangeError` at iteration time into an earlier `InvalidArgumentError` at initialization time, which will improve error reporting in Eager mode and high-level APIs that control the training loop (Estimator, Keras, etc.). This change will break some advanced uses that concatenate many potentially empty file listings, but it is possible to work around this using `tf.data.Dataset.from_tensor_slices(tf.matching_files(file_pattern))`. We expect that the improved productivity from an earlier, more actionable error message will outweigh the inconvenience of modifying a small number of existing programs. PiperOrigin-RevId: 207344116
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py12
-rw-r--r--tensorflow/python/data/ops/BUILD2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py41
3 files changed, 37 insertions, 18 deletions
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index f7d7d085c9..579096f880 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -123,13 +123,11 @@ class ListFilesDatasetOpTest(test.TestCase):
with self.test_session() as sess:
itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, 'No files matched pattern: '):
+ sess.run(
+ itr.initializer,
+ feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
def testSimpleDirectoryInitializer(self):
filenames = ['a', 'b', 'c']
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 61bf9783ab..50ba5f403e 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -11,6 +11,7 @@ py_library(
deps = [
":iterator_ops",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
@@ -19,6 +20,7 @@ py_library(
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 88de4b588c..6cda2a77cc 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -39,10 +39,12 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -644,17 +646,34 @@ class Dataset(object):
Returns:
Dataset: A `Dataset` of strings corresponding to file names.
"""
- if shuffle is None:
- shuffle = True
- matching_files = gen_io_ops.matching_files(file_pattern)
- dataset = Dataset.from_tensor_slices(matching_files)
- if shuffle:
- # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
- # list of files might be empty.
- buffer_size = math_ops.maximum(
- array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
- dataset = dataset.shuffle(buffer_size, seed=seed)
- return dataset
+ with ops.name_scope("list_files"):
+ if shuffle is None:
+ shuffle = True
+ file_pattern = ops.convert_to_tensor(
+ file_pattern, dtype=dtypes.string, name="file_pattern")
+ matching_files = gen_io_ops.matching_files(file_pattern)
+
+ # Raise an exception if `file_pattern` does not match any files.
+ condition = math_ops.greater(array_ops.shape(matching_files)[0], 0,
+ name="match_not_empty")
+
+ message = math_ops.add(
+ "No files matched pattern: ",
+ string_ops.reduce_join(file_pattern, separator=", "), name="message")
+
+ assert_not_empty = control_flow_ops.Assert(
+ condition, [message], summarize=1, name="assert_not_empty")
+ with ops.control_dependencies([assert_not_empty]):
+ matching_files = array_ops.identity(matching_files)
+
+ dataset = Dataset.from_tensor_slices(matching_files)
+ if shuffle:
+ # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
+ # list of files might be empty.
+ buffer_size = math_ops.maximum(
+ array_ops.shape(matching_files, out_type=dtypes.int64)[0], 1)
+ dataset = dataset.shuffle(buffer_size, seed=seed)
+ return dataset
def repeat(self, count=None):
"""Repeats this dataset `count` times.