aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
blob: 579096f88097ad9a724b029b7dfd74d04b75f90a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from os import path
import shutil
import tempfile

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class ListFilesDatasetOpTest(test.TestCase):
  """Tests for `dataset_ops.Dataset.list_files()`.

  Each test gets a fresh temporary directory (`self.tmp_dir`), created in
  `setUp()` and removed in `tearDown()`, in which it creates the files the
  glob patterns are matched against.
  """

  def setUp(self):
    # Run the base-class setUp first so whatever per-test initialization
    # `test.TestCase` performs is not skipped by this override.
    super(ListFilesDatasetOpTest, self).setUp()
    self.tmp_dir = tempfile.mkdtemp()

  def tearDown(self):
    # Remove the fixture directory (best-effort), then let the base class run
    # its own per-test cleanup.
    shutil.rmtree(self.tmp_dir, ignore_errors=True)
    super(ListFilesDatasetOpTest, self).tearDown()

  def _touchTempFiles(self, filenames):
    """Creates an empty file in `self.tmp_dir` for each name in `filenames`.

    Files are opened in append mode, so a file that already exists is left
    unchanged.
    """
    for filename in filenames:
      with open(path.join(self.tmp_dir, filename), 'a'):
        pass

  def _fullFilenames(self, filenames):
    """Returns the paths of `filenames` inside `self.tmp_dir`, as bytes."""
    return [compat.as_bytes(path.join(self.tmp_dir, filename))
            for filename in filenames]

  def testEmptyDirectory(self):
    """A pattern matching nothing yields an immediately exhausted dataset."""
    dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
    with self.test_session() as sess:
      itr = dataset.make_one_shot_iterator()
      next_element = itr.get_next()
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testSimpleDirectory(self):
    """Every file in the directory is produced exactly once."""
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
    with self.test_session() as sess:
      itr = dataset.make_one_shot_iterator()
      next_element = itr.get_next()

      full_filenames = self._fullFilenames(filenames)
      produced_filenames = [
          compat.as_bytes(sess.run(next_element)) for _ in filenames]
      # Only the set of produced names is checked here, not their order.
      self.assertItemsEqual(full_filenames, produced_filenames)
      # Reuse the existing get_next op rather than building a new one.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testSimpleDirectoryNotShuffled(self):
    """With `shuffle=False`, files are produced in sorted order."""
    filenames = ['b', 'c', 'a']
    self._touchTempFiles(filenames)

    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=False)
    with self.test_session() as sess:
      itr = dataset.make_one_shot_iterator()
      next_element = itr.get_next()

      for expected in self._fullFilenames(sorted(filenames)):
        self.assertEqual(expected, sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testFixedSeedResultsInRepeatableOrder(self):
    """A fixed shuffle seed gives the same order on every initialization."""
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
    with self.test_session() as sess:
      itr = dataset.make_initializable_iterator()
      next_element = itr.get_next()

      full_filenames = self._fullFilenames(filenames)

      all_produced_filenames = []
      for _ in range(3):
        produced_filenames = []
        # Re-initialize so each pass starts from the beginning of the dataset.
        sess.run(itr.initializer)
        try:
          while True:
            produced_filenames.append(sess.run(next_element))
        except errors.OutOfRangeError:
          pass
        all_produced_filenames.append(produced_filenames)

      # Each run should produce the same set of filenames, which may be
      # different from the order of `full_filenames`.
      self.assertItemsEqual(full_filenames, all_produced_filenames[0])
      # However, the different runs should produce filenames in the same order
      # as each other.
      self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
      self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])

  def testEmptyDirectoryInitializer(self):
    """A fed pattern matching nothing fails at iterator initialization."""
    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    with self.test_session() as sess:
      itr = dataset.make_initializable_iterator()
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError, 'No files matched pattern: '):
        sess.run(
            itr.initializer,
            feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})

  def testSimpleDirectoryInitializer(self):
    """A pattern fed through a placeholder lists every matching file."""
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    with self.test_session() as sess:
      itr = dataset.make_initializable_iterator()
      next_element = itr.get_next()
      sess.run(
          itr.initializer,
          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})

      full_filenames = self._fullFilenames(filenames)
      produced_filenames = [
          compat.as_bytes(sess.run(next_element)) for _ in filenames]

      self.assertItemsEqual(full_filenames, produced_filenames)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testFileSuffixes(self):
    """`*.py` matches only names ending exactly in `.py`."""
    filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
    self._touchTempFiles(filenames)

    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    with self.test_session() as sess:
      itr = dataset.make_initializable_iterator()
      next_element = itr.get_next()
      sess.run(
          itr.initializer,
          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})

      # Only 'b.py' and 'c.py' should match; 'a.txt' and 'd.pyc' should not.
      expected = filenames[1:-1]
      full_filenames = self._fullFilenames(expected)
      produced_filenames = [
          compat.as_bytes(sess.run(next_element)) for _ in expected]
      self.assertItemsEqual(full_filenames, produced_filenames)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testFileMiddles(self):
    """`*.py*` matches names with `.py` anywhere after the first character."""
    filenames = ['a.txt', 'b.py', 'c.pyc']
    self._touchTempFiles(filenames)

    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    with self.test_session() as sess:
      itr = dataset.make_initializable_iterator()
      next_element = itr.get_next()
      sess.run(
          itr.initializer,
          feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})

      # Both 'b.py' and 'c.pyc' should match; 'a.txt' should not.
      expected = filenames[1:]
      full_filenames = self._fullFilenames(expected)
      produced_filenames = [
          compat.as_bytes(sess.run(next_element)) for _ in expected]

      self.assertItemsEqual(full_filenames, produced_filenames)

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testNoShuffle(self):
    """With `shuffle=False`, repeated epochs produce the same order."""
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    # Repeat the list twice and ensure that the order is the same each time.
    # NOTE(mrry): This depends on an implementation detail of `list_files()`,
    # which is that the list of files is captured when the iterator is
    # initialized. Otherwise, or if e.g. the iterator were initialized more than
    # once, it's possible that the non-determinism of `tf.matching_files()`
    # would cause this test to fail. However, it serves as a useful confirmation
    # that the `shuffle=False` argument is working as intended.
    # TODO(b/73959787): Provide some ordering guarantees so that this test is
    # more meaningful.
    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
    with self.test_session() as sess:
      itr = dataset.make_one_shot_iterator()
      next_element = itr.get_next()

      expected = filenames * 2
      full_filenames = self._fullFilenames(expected)
      produced_filenames = [
          compat.as_bytes(sess.run(next_element)) for _ in expected]
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)
      self.assertItemsEqual(full_filenames, produced_filenames)
      # The second epoch must repeat the first epoch's order exactly.
      self.assertEqual(produced_filenames[:len(filenames)],
                       produced_filenames[len(filenames):])


if __name__ == '__main__':
  # Script entry point: run this module's test cases via TensorFlow's
  # `test.main()` runner.
  test.main()