aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
blob: b7025f3802c0c280981df20c86747e49fdf2274f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import itertools
import os
import time

import numpy as np

from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat

# Seed passed to `np.random.seed()` before the benchmarks run, so that the
# randomly generated benchmark inputs are the same on every invocation.
_NUMPY_RANDOM_SEED = 42


class MapDatasetTest(test.TestCase):
  """Tests `map()` with `ignore_errors()` and with captured iterators."""

  def testMapIgnoreError(self):
    """Checks that a failing element is skipped in a sequential `map()`."""
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    # `check_numerics` is expected to fail on the NaN element; wrapping the
    # pipeline in `ignore_errors()` should drop that element instead of
    # surfacing the error.
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: array_ops.check_numerics(x, "message")).apply(
            error_ops.ignore_errors()))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Only the four finite values should be produced, in order.
      for x in [1., 2., 3., 5.]:
        self.assertEqual(x, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testParallelMapIgnoreError(self):
    """Checks that a failing element is skipped in a parallel `map()`."""
    components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)

    # Same pipeline as the sequential test, but the map runs with
    # `num_parallel_calls=2` and is followed by `prefetch(2)`.
    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components).map(
            lambda x: array_ops.check_numerics(x, "message"),
            num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # Only the four finite values should be produced, in order.
      for x in [1., 2., 3., 5.]:
        self.assertEqual(x, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReadFileIgnoreError(self):
    """Checks that an unreadable file is skipped when reading in a `map()`."""

    def write_string_to_file(value, filename):
      """Writes the string `value` to the file at path `filename`."""
      with open(filename, "w") as f:
        f.write(value)

    filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i)
                 for i in range(5)]
    # Each file's content is its own path, so the value read back identifies
    # which file it came from.
    for filename in filenames:
      write_string_to_file(filename, filename)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(filenames).map(
            io_ops.read_file, num_parallel_calls=2).prefetch(2).apply(
                error_ops.ignore_errors()))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      # All of the files are present.
      sess.run(init_op)
      for filename in filenames:
        self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Delete one of the files.
      os.remove(filenames[0])

      # Attempting to read filenames[0] will fail, but ignore_errors()
      # will catch the error.
      sess.run(init_op)
      for filename in filenames[1:]:
        self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testCaptureResourceInMapFn(self):
    """Checks that a map function can use an iterator created outside it."""

    def _build_ds(iterator):
      """Returns `range(10)` mapped through a function that reads `iterator`."""

      def _map_fn(x):
        # Each call pulls the next value from the captured iterator and
        # multiplies it with the current element.
        get_next = iterator.get_next()
        return x * get_next

      return dataset_ops.Dataset.range(10).map(_map_fn)

    def _build_graph():
      """Builds the pipeline; returns both initializers and the output."""
      captured_iterator = dataset_ops.Dataset.range(
          10).make_initializable_iterator()
      ds = _build_ds(captured_iterator)
      iterator = ds.make_initializable_iterator()
      init_op = iterator.initializer
      get_next = iterator.get_next()
      return captured_iterator.initializer, init_op, get_next

    with ops.Graph().as_default() as g:
      captured_init_op, init_op, get_next = _build_graph()
      with self.test_session(graph=g) as sess:
        # The captured iterator is initialized before the outer one.
        sess.run(captured_init_op)
        sess.run(init_op)
        # Both datasets are `range(10)`, so the i-th output is expected to be
        # `i * i`.
        for i in range(10):
          # `assertEquals` is a deprecated alias (removed in Python 3.12);
          # use `assertEqual` like the other tests in this class.
          self.assertEqual(i * i, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)


class MapDatasetBenchmark(test.Benchmark):
  """Benchmarks chained `map().batch()` against fused `map_and_batch()`."""

  # The purpose of this benchmark is to compare the performance of chaining vs
  # fusing of the map and batch transformations across various configurations.
  #
  # NOTE: It is recommended to build the benchmark with
  # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
  # and execute it on a machine with at least 32 CPU cores.
  def benchmarkMapAndBatch(self):
    """Times the chained and fused pipelines over several parameter sweeps.

    Every configuration is a
    `(num_calls, inter_op, element_size, batch_size)` tuple. For each one, a
    matmul is mapped over a repeated random input and batched, once with
    `map(...).batch(...)` and once with `batching.map_and_batch(...)`. The
    per-step wall times are printed and the medians are passed to
    `report_benchmark()`.
    """

    # Sequential pipeline configurations.
    seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
    seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])

    # Parallel pipeline configuration.
    par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
    par_batch_size_series = itertools.product([32], [32], [1],
                                              [128, 256, 512, 1024])
    par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
    par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])

    def name(method, label, num_calls, inter_op, element_size, batch_size):
      """Builds the reported benchmark name for one method and configuration.

      Args:
        method: Pipeline variant, e.g. "chained" or "fused".
        label: Human-readable series label; only its SHA-1 hex digest is used.
        num_calls: Value of `num_parallel_calls`.
        inter_op: Value of `inter_op_parallelism_threads`.
        element_size: Number of rows of the first matmul operand.
        batch_size: Batch size of the pipeline.

      Returns:
        The benchmark name as a string.
      """
      # `hashlib.sha1()` only accepts bytes on Python 3, so the label is
      # encoded first. The labels are ASCII, so the digest is the same as the
      # one Python 2 computes from the unencoded string.
      return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
          method,
          hashlib.sha1(label.encode("utf-8")).hexdigest(),
          num_calls,
          inter_op,
          element_size,
          batch_size,
      ))

    def measure(get_next, inter_op, num_iters):
      """Runs `get_next.op` repeatedly and returns the per-step wall times.

      Args:
        get_next: Output of `Iterator.get_next()` for the pipeline to time.
        inter_op: Value for `inter_op_parallelism_threads` of the session.
        num_iters: Number of timed steps.

      Returns:
        A list with `num_iters` wall-clock durations in seconds.
      """
      deltas = []
      with session.Session(
          config=config_pb2.ConfigProto(
              inter_op_parallelism_threads=inter_op,
              use_per_session_threads=True)) as sess:
        # Untimed warm-up steps.
        for _ in range(5):
          sess.run(get_next.op)
        for _ in range(num_iters):
          start = time.time()
          sess.run(get_next.op)
          end = time.time()
          deltas.append(end - start)
      return deltas

    def benchmark(label, series):
      """Benchmarks both pipeline variants for every configuration in `series`.

      Args:
        label: Human-readable name of the series, printed as a heading and
          hashed into the reported benchmark names.
        series: Iterable of `(num_calls, inter_op, element_size, batch_size)`
          tuples.
      """

      print("%s:" % label)
      for num_calls, inter_op, element_size, batch_size in series:

        # Fewer timed steps are run as the amount of work per step
        # (`element_size * batch_size`, relative to the available
        # parallelism) grows.
        num_iters = 1024 // (
            (element_size * batch_size) // min(num_calls, inter_op))
        k = 1024 * 1024
        # A single pair of random matmul operands, repeated indefinitely.
        dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
            element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()

        chained_dataset = dataset.map(
            math_ops.matmul,
            num_parallel_calls=num_calls).batch(batch_size=batch_size)
        chained_iterator = chained_dataset.make_one_shot_iterator()
        chained_get_next = chained_iterator.get_next()
        chained_deltas = measure(chained_get_next, inter_op, num_iters)

        fused_dataset = dataset.apply(
            batching.map_and_batch(
                math_ops.matmul,
                num_parallel_calls=num_calls,
                batch_size=batch_size))
        fused_iterator = fused_dataset.make_one_shot_iterator()
        fused_get_next = fused_iterator.get_next()
        fused_deltas = measure(fused_get_next, inter_op, num_iters)

        print(
            "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
            "element size: %d, num iters: %d\nchained wall time: %f (median), "
            "%f (mean), %f (stddev), %f (min), %f (max)\n  fused wall time: "
            "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n    "
            "chained/fused:    %.2fx (median),    %.2fx (mean)" %
            (batch_size, num_calls, inter_op, element_size, num_iters,
             np.median(chained_deltas), np.mean(chained_deltas),
             np.std(chained_deltas), np.min(chained_deltas),
             np.max(chained_deltas), np.median(fused_deltas),
             np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
             np.max(fused_deltas),
             np.median(chained_deltas) / np.median(fused_deltas),
             np.mean(chained_deltas) / np.mean(fused_deltas)))

        self.report_benchmark(
            iters=num_iters,
            wall_time=np.median(chained_deltas),
            name=name("chained", label, num_calls, inter_op, element_size,
                      batch_size))

        self.report_benchmark(
            iters=num_iters,
            wall_time=np.median(fused_deltas),
            name=name("fused", label, num_calls, inter_op, element_size,
                      batch_size))

      print("")

    # Seed NumPy once so the random inputs are the same on every invocation.
    np.random.seed(_NUMPY_RANDOM_SEED)
    benchmark("Sequential element size evaluation", seq_elem_size_series)
    benchmark("Sequential batch size evaluation", seq_batch_size_series)
    benchmark("Parallel element size evaluation", par_elem_size_series)
    benchmark("Parallel batch size evaluation", par_batch_size_series)
    benchmark("Transformation parallelism evaluation", par_num_calls_series)
    benchmark("Threadpool size evaluation", par_inter_op_series)

# When this file is executed directly, hand control to `test.main()` from
# `tensorflow.python.platform.test`.
if __name__ == "__main__":
  test.main()