aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/summary/event_accumulator_test.py
blob: 2050c346cdc25b3b1f4d47b34caa0d270b31b031 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.core.framework import graph_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import logging
from tensorflow.python.summary import event_accumulator as ea


class _EventGenerator(object):

  def __init__(self):
    self.items = []

  def Load(self):
    while self.items:
      yield self.items.pop(0)

  def AddScalar(self, tag, wall_time=0, step=0, value=0):
    event = tf.Event(
        wall_time=wall_time,
        step=step,
        summary=tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                   simple_value=value)]))
    self.AddEvent(event)

  def AddHistogram(self,
                   tag,
                   wall_time=0,
                   step=0,
                   hmin=1,
                   hmax=2,
                   hnum=3,
                   hsum=4,
                   hsum_squares=5,
                   hbucket_limit=None,
                   hbucket=None):
    histo = tf.HistogramProto(min=hmin,
                              max=hmax,
                              num=hnum,
                              sum=hsum,
                              sum_squares=hsum_squares,
                              bucket_limit=hbucket_limit,
                              bucket=hbucket)
    event = tf.Event(wall_time=wall_time,
                     step=step,
                     summary=tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                                histo=histo)]))
    self.AddEvent(event)

  def AddImage(self,
               tag,
               wall_time=0,
               step=0,
               encoded_image_string=b'imgstr',
               width=150,
               height=100):
    image = tf.Summary.Image(encoded_image_string=encoded_image_string,
                             width=width,
                             height=height)
    event = tf.Event(wall_time=wall_time,
                     step=step,
                     summary=tf.Summary(value=[tf.Summary.Value(tag=tag,
                                                                image=image)]))
    self.AddEvent(event)

  def AddEvent(self, event):
    self.items.append(event)


class EventAccumulatorTest(tf.test.TestCase):
  """Base test case providing a helper for comparing Tags() dictionaries."""

  def assertTagsEqual(self, tags1, tags2):
    """Asserts that two tag dictionaries have the same keys and values.

    List-valued entries are compared as multisets (order is ignored); any
    other entry must compare equal.
    """
    # Both dictionaries must expose exactly the same set of keys.
    self.assertItemsEqual(tags1, tags2)
    for key, expected in tags1.items():
      actual = tags2[key]
      if not isinstance(expected, list):
        self.assertEqual(expected, actual)
        continue
      # The order of tags inside one category is not significant.
      self.assertItemsEqual(expected, actual)


class MockingEventAccumulatorTest(EventAccumulatorTest):
  """Tests EventAccumulator against an in-memory _EventGenerator.

  setUp replaces ea.EventAccumulator with a wrapper so that tests can pass a
  _EventGenerator where a path would normally go. Before delegating to the
  real constructor, the wrapper rebinds ea._GeneratorFromPath to a function
  that returns that generator. tearDown restores both module attributes.
  """

  def setUp(self):
    super(MockingEventAccumulatorTest, self).setUp()
    self.stubs = googletest.StubOutForTesting()
    # Expected Tags() result for an accumulator that has loaded no data.
    self.empty = {ea.IMAGES: [],
                  ea.SCALARS: [],
                  ea.HISTOGRAMS: [],
                  ea.COMPRESSED_HISTOGRAMS: [],
                  ea.GRAPH: False,
                  ea.RUN_METADATA: []}
    # Remember the real module attributes so tearDown can put them back.
    self._real_constructor = ea.EventAccumulator
    self._real_generator = ea._GeneratorFromPath

    def _FakeAccumulatorConstructor(generator, *args, **kwargs):
      # NOTE(review): assumes the real constructor obtains its event source
      # via ea._GeneratorFromPath(path); the lambda ignores that argument and
      # hands back the in-memory generator instead.
      ea._GeneratorFromPath = lambda x: generator
      return self._real_constructor(generator, *args, **kwargs)

    ea.EventAccumulator = _FakeAccumulatorConstructor

  def tearDown(self):
    try:
      self.stubs.CleanUp()
    finally:
      # Undo the module-level patches even if CleanUp raises. Leaking the
      # fake constructor would silently affect later tests in this module.
      ea.EventAccumulator = self._real_constructor
      ea._GeneratorFromPath = self._real_generator
      super(MockingEventAccumulatorTest, self).tearDown()

  def testEmptyAccumulator(self):
    """An accumulator over an empty generator reports no tags."""
    gen = _EventGenerator()
    x = ea.EventAccumulator(gen)
    x.Reload()
    self.assertEqual(x.Tags(), self.empty)

  def testTags(self):
    """Tags() lists every scalar, histogram and image tag that was loaded."""
    gen = _EventGenerator()
    gen.AddScalar('s1')
    gen.AddScalar('s2')
    gen.AddHistogram('hst1')
    gen.AddHistogram('hst2')
    gen.AddImage('im1')
    gen.AddImage('im2')
    acc = ea.EventAccumulator(gen)
    acc.Reload()
    self.assertTagsEqual(acc.Tags(), {
        ea.IMAGES: ['im1', 'im2'],
        ea.SCALARS: ['s1', 's2'],
        ea.HISTOGRAMS: ['hst1', 'hst2'],
        ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
        ea.GRAPH: False,
        ea.RUN_METADATA: []
    })

  def testReload(self):
    """Events added to the source only become visible after Reload()."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    acc.Reload()
    self.assertEqual(acc.Tags(), self.empty)
    gen.AddScalar('s1')
    gen.AddScalar('s2')
    gen.AddHistogram('hst1')
    gen.AddHistogram('hst2')
    gen.AddImage('im1')
    gen.AddImage('im2')
    self.assertEqual(acc.Tags(), self.empty)
    acc.Reload()
    self.assertTagsEqual(acc.Tags(), {
        ea.IMAGES: ['im1', 'im2'],
        ea.SCALARS: ['s1', 's2'],
        ea.HISTOGRAMS: ['hst1', 'hst2'],
        ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'],
        ea.GRAPH: False,
        ea.RUN_METADATA: []
    })

  def testScalars(self):
    """Scalars() returns ScalarEvents matching what was written per tag."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    s1 = ea.ScalarEvent(wall_time=1, step=10, value=32)
    s2 = ea.ScalarEvent(wall_time=2, step=12, value=64)
    gen.AddScalar('s1', wall_time=1, step=10, value=32)
    gen.AddScalar('s2', wall_time=2, step=12, value=64)
    acc.Reload()
    self.assertEqual(acc.Scalars('s1'), [s1])
    self.assertEqual(acc.Scalars('s2'), [s2])

  def testHistograms(self):
    """Histograms() returns HistogramEvents matching what was written."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)

    val1 = ea.HistogramValue(min=1,
                             max=2,
                             num=3,
                             sum=4,
                             sum_squares=5,
                             bucket_limit=[1, 2, 3],
                             bucket=[0, 3, 0])
    val2 = ea.HistogramValue(min=-2,
                             max=3,
                             num=4,
                             sum=5,
                             sum_squares=6,
                             bucket_limit=[2, 3, 4],
                             bucket=[1, 3, 0])

    hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1)
    hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2)
    gen.AddHistogram('hst1',
                     wall_time=1,
                     step=10,
                     hmin=1,
                     hmax=2,
                     hnum=3,
                     hsum=4,
                     hsum_squares=5,
                     hbucket_limit=[1, 2, 3],
                     hbucket=[0, 3, 0])
    gen.AddHistogram('hst2',
                     wall_time=2,
                     step=12,
                     hmin=-2,
                     hmax=3,
                     hnum=4,
                     hsum=5,
                     hsum_squares=6,
                     hbucket_limit=[2, 3, 4],
                     hbucket=[1, 3, 0])
    acc.Reload()
    self.assertEqual(acc.Histograms('hst1'), [hst1])
    self.assertEqual(acc.Histograms('hst2'), [hst2])

  def testCompressedHistograms(self):
    """Histograms are compressed to values at the requested basis points."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000))

    gen.AddHistogram('hst1',
                     wall_time=1,
                     step=10,
                     hmin=1,
                     hmax=2,
                     hnum=3,
                     hsum=4,
                     hsum_squares=5,
                     hbucket_limit=[1, 2, 3],
                     hbucket=[0, 3, 0])
    gen.AddHistogram('hst2',
                     wall_time=2,
                     step=12,
                     hmin=-2,
                     hmax=3,
                     hnum=4,
                     hsum=5,
                     hsum_squares=6,
                     hbucket_limit=[2, 3, 4],
                     hbucket=[1, 3, 0])
    acc.Reload()

    # Create the expected values after compressing hst1
    expected_vals1 = [ea.CompressedHistogramValue(bp, val)
                      for bp, val in [(0, 1.0), (2500, 1.25), (5000, 1.5),
                                      (7500, 1.75), (10000, 2.0)]]
    expected_cmphst1 = ea.CompressedHistogramEvent(
        wall_time=1,
        step=10,
        compressed_histogram_values=expected_vals1)
    self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1])

    # Create the expected values after compressing hst2. The module imports
    # `division` from __future__, so 1 / 3 is a true (float) division.
    expected_vals2 = [
        ea.CompressedHistogramValue(bp, val)
        for bp, val in [(0, -2), (2500, 2), (5000, 2 + 1 / 3),
                        (7500, 2 + 2 / 3), (10000, 3)]
    ]
    expected_cmphst2 = ea.CompressedHistogramEvent(
        wall_time=2,
        step=12,
        compressed_histogram_values=expected_vals2)
    self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2])

  def testPercentile(self):
    """_Percentile interpolates within the bucket that holds each bps."""

    def AssertExpectedForBps(bps, expected):
      # Reads bucket_limit, cumsum_weights, histo_min, histo_max and
      # histo_num from the enclosing scope at call time, so each scenario
      # below only needs to rebind them before asserting.
      output = acc._Percentile(bps, bucket_limit, cumsum_weights, histo_min,
                               histo_max, histo_num)
      self.assertAlmostEqual(expected, output)

    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)

    bucket_limit = [1, 2, 3, 4]
    histo_num = 100

    ## All weights in the first bucket
    cumsum_weights = [10000, 10000, 10000, 10000]
    histo_min = -1
    histo_max = .9
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(5000, ea._Remap(5000, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(7500, ea._Remap(7500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(10000, histo_max)

    ## All weights in second bucket
    cumsum_weights = [0, 10000, 10000, 10000]
    histo_min = 1.1
    histo_max = 1.8
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(5000, ea._Remap(5000, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(7500, ea._Remap(7500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(10000, histo_max)

    ## All weights in the last bucket
    cumsum_weights = [0, 0, 0, 10000]
    histo_min = 3.1
    histo_max = 3.6
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(5000, ea._Remap(5000, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(7500, ea._Remap(7500, 0, 10000, histo_min, histo_max))
    AssertExpectedForBps(10000, histo_max)

    ## Weights distributed between two buckets
    cumsum_weights = [0, 4000, 10000, 10000]
    histo_min = 1.1
    histo_max = 2.9
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 0, 4000, histo_min,
                                         bucket_limit[1]))
    AssertExpectedForBps(5000, ea._Remap(5000, 4000, 10000, bucket_limit[1],
                                         histo_max))
    AssertExpectedForBps(7500, ea._Remap(7500, 4000, 10000, bucket_limit[1],
                                         histo_max))
    AssertExpectedForBps(10000, histo_max)

    ## Weights distributed between all buckets
    cumsum_weights = [1000, 4000, 8000, 10000]
    histo_min = -1
    histo_max = 3.9
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 1000, 4000, bucket_limit[0],
                                         bucket_limit[1]))
    AssertExpectedForBps(5000, ea._Remap(5000, 4000, 8000, bucket_limit[1],
                                         bucket_limit[2]))
    AssertExpectedForBps(7500, ea._Remap(7500, 4000, 8000, bucket_limit[1],
                                         bucket_limit[2]))
    AssertExpectedForBps(9000, ea._Remap(9000, 8000, 10000, bucket_limit[2],
                                         histo_max))
    AssertExpectedForBps(10000, histo_max)

    ## Most weight in first bucket
    cumsum_weights = [9000, 10000, 10000, 10000]
    histo_min = -1
    histo_max = 1.1
    AssertExpectedForBps(0, histo_min)
    AssertExpectedForBps(2500, ea._Remap(2500, 0, 9000, histo_min,
                                         bucket_limit[0]))
    AssertExpectedForBps(5000, ea._Remap(5000, 0, 9000, histo_min,
                                         bucket_limit[0]))
    AssertExpectedForBps(7500, ea._Remap(7500, 0, 9000, histo_min,
                                         bucket_limit[0]))
    AssertExpectedForBps(9500, ea._Remap(9500, 9000, 10000, bucket_limit[0],
                                         histo_max))
    AssertExpectedForBps(10000, histo_max)

  def testImages(self):
    """Images() returns ImageEvents matching what was written per tag."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    im1 = ea.ImageEvent(wall_time=1,
                        step=10,
                        encoded_image_string=b'big',
                        width=400,
                        height=300)
    im2 = ea.ImageEvent(wall_time=2,
                        step=12,
                        encoded_image_string=b'small',
                        width=40,
                        height=30)
    gen.AddImage('im1',
                 wall_time=1,
                 step=10,
                 encoded_image_string=b'big',
                 width=400,
                 height=300)
    gen.AddImage('im2',
                 wall_time=2,
                 step=12,
                 encoded_image_string=b'small',
                 width=40,
                 height=30)
    acc.Reload()
    self.assertEqual(acc.Images('im1'), [im1])
    self.assertEqual(acc.Images('im2'), [im2])

  def testActivation(self):
    """Accessors raise RuntimeError until the first Reload()."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    self.assertFalse(acc._activated)
    with self.assertRaises(RuntimeError):
      acc.Tags()
    with self.assertRaises(RuntimeError):
      acc.Scalars('s1')
    acc.Reload()
    self.assertTrue(acc._activated)

  def testKeyError(self):
    """Looking up an unknown tag, or one of the wrong kind, is a KeyError."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    acc.Reload()
    with self.assertRaises(KeyError):
      acc.Scalars('s1')
    with self.assertRaises(KeyError):
      acc.Scalars('hst1')
    with self.assertRaises(KeyError):
      acc.Scalars('im1')
    with self.assertRaises(KeyError):
      acc.Histograms('s1')
    with self.assertRaises(KeyError):
      acc.Histograms('im1')
    with self.assertRaises(KeyError):
      acc.Images('s1')
    with self.assertRaises(KeyError):
      acc.Images('hst1')

  def testNonValueEvents(self):
    """Tests that non-value events in the generator don't cause early exits."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    gen.AddScalar('s1', wall_time=1, step=10, value=20)
    gen.AddEvent(tf.Event(wall_time=2, step=20, file_version='nots2'))
    gen.AddScalar('s3', wall_time=3, step=100, value=1)
    gen.AddHistogram('hst1')
    gen.AddImage('im1')

    acc.Reload()
    self.assertTagsEqual(acc.Tags(), {
        ea.IMAGES: ['im1'],
        ea.SCALARS: ['s1', 's3'],
        ea.HISTOGRAMS: ['hst1'],
        ea.COMPRESSED_HISTOGRAMS: ['hst1'],
        ea.GRAPH: False,
        ea.RUN_METADATA: []
    })

  def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
    """Tests that events are discarded after a restart is detected.

    If a step value is observed to be lower than what was previously seen,
    this should force a discard of all previous items with the same tag
    that are outdated.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    # Swallow warnings emitted through logging.warn during this test.
    warnings = []
    self.stubs.Set(logging, 'warn', warnings.append)

    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)

    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    acc.Reload()
    ## Check that number of items are what they should be
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])

    gen.AddScalar('s1', wall_time=1, step=101, value=20)
    gen.AddScalar('s1', wall_time=1, step=201, value=20)
    gen.AddScalar('s1', wall_time=1, step=301, value=20)
    acc.Reload()
    ## Check that we have discarded 200 and 300 from s1
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])

  def testOrphanedDataNotDiscardedIfFlagUnset(self):
    """Tests that events are not discarded if purge_orphaned_data is false.

    Out-of-order steps that would otherwise trigger a purge are kept and
    simply appended after the existing events.
    """
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen, purge_orphaned_data=False)

    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    acc.Reload()
    ## Check that number of items are what they should be
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])

    gen.AddScalar('s1', wall_time=1, step=101, value=20)
    gen.AddScalar('s1', wall_time=1, step=201, value=20)
    gen.AddScalar('s1', wall_time=1, step=301, value=20)
    acc.Reload()
    ## Check that nothing was discarded: 200 and 300 are still present and the
    ## out-of-order steps are appended in arrival order.
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300, 101,
                                                           201, 301])

  def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
    """Tests that event discards after restart, only affect the misordered tag.

    If a step value is observed to be lower than what was previously seen,
    this should force a discard of all previous items that are outdated, but
    only for the out of order tag. Other tags should remain unaffected.

    Only file versions < 2 use this out-of-order discard logic. Later versions
    discard events based on the step value of SessionLog.START.
    """
    # Swallow warnings emitted through logging.warn during this test.
    warnings = []
    self.stubs.Set(logging, 'warn', warnings.append)

    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)

    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    gen.AddScalar('s1', wall_time=1, step=101, value=20)
    gen.AddScalar('s1', wall_time=1, step=201, value=20)
    gen.AddScalar('s1', wall_time=1, step=301, value=20)

    gen.AddScalar('s2', wall_time=1, step=101, value=20)
    gen.AddScalar('s2', wall_time=1, step=201, value=20)
    gen.AddScalar('s2', wall_time=1, step=301, value=20)

    acc.Reload()
    ## Check that we have discarded 200 and 300
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])

    ## Check that s1 discards do not affect s2
    ## i.e. check that only events from the out of order tag are discarded
    self.assertEqual([x.step for x in acc.Scalars('s2')], [101, 201, 301])

  def testOnlySummaryEventsTriggerDiscards(self):
    """Test that file version event does not trigger data purge."""
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    ev1 = tf.Event(wall_time=2, step=0, file_version='brain.Event:1')
    graph_bytes = graph_pb2.GraphDef().SerializeToString()
    ev2 = tf.Event(wall_time=3, step=0, graph_def=graph_bytes)
    gen.AddEvent(ev1)
    gen.AddEvent(ev2)
    acc.Reload()
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100])

  def testSessionLogStartMessageDiscardsExpiredEvents(self):
    """Test that SessionLog.START message discards expired events.

    This discard logic is preferred over the out-of-order step discard logic,
    but this logic can only be used for event protos which have the SessionLog
    enum, which was introduced to event.proto for file_version >= brain.Event:2.
    """
    gen = _EventGenerator()
    acc = ea.EventAccumulator(gen)
    gen.AddEvent(tf.Event(wall_time=0, step=1, file_version='brain.Event:2'))

    gen.AddScalar('s1', wall_time=1, step=100, value=20)
    gen.AddScalar('s1', wall_time=1, step=200, value=20)
    gen.AddScalar('s1', wall_time=1, step=300, value=20)
    gen.AddScalar('s1', wall_time=1, step=400, value=20)

    gen.AddScalar('s2', wall_time=1, step=202, value=20)
    gen.AddScalar('s2', wall_time=1, step=203, value=20)

    # A restart at step 201 expires every event with a step >= 201.
    slog = SessionLog(status=SessionLog.START)
    gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=slog))
    acc.Reload()
    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200])
    self.assertEqual([x.step for x in acc.Scalars('s2')], [])


class RealisticEventAccumulatorTest(EventAccumulatorTest):
  """Round-trips data through a real SummaryWriter and event files on disk."""

  def testScalarsRealistically(self):
    """Test accumulator by writing values and then reading them."""

    def FakeScalarSummary(tag, value):
      value = tf.Summary.Value(tag=tag, simple_value=value)
      return tf.Summary(value=[value])

    def AssertScalarsLoaded(count):
      # Fetch fresh results from the accumulator rather than reusing lists
      # obtained before an earlier Reload(): whether Scalars() returns a live
      # view or a snapshot is an implementation detail this test should not
      # depend on.
      id_events = acc.Scalars('id')
      sq_events = acc.Scalars('sq')
      self.assertEqual(count, len(id_events))
      self.assertEqual(count, len(sq_events))
      for i in xrange(count):
        self.assertEqual(i * 5, id_events[i].step)
        self.assertEqual(i * 5, sq_events[i].step)
        self.assertEqual(i, id_events[i].value)
        self.assertEqual(i * i, sq_events[i].value)

    directory = os.path.join(self.get_temp_dir(), 'values_dir')
    if gfile.IsDirectory(directory):
      gfile.DeleteRecursively(directory)
    gfile.MkDir(directory)

    writer = tf.train.SummaryWriter(directory, max_queue=100)
    # Release the writer's resources even if an assertion below fails.
    self.addCleanup(writer.close)
    graph_def = tf.GraphDef(node=[tf.NodeDef(name='A', op='Mul')])
    # Add a graph to the summary writer.
    writer.add_graph(graph_def)

    run_metadata = tf.RunMetadata()
    device_stats = run_metadata.step_stats.dev_stats.add()
    device_stats.device = 'test device'
    writer.add_run_metadata(run_metadata, 'test run')

    # Write a bunch of events using the writer
    for i in xrange(30):
      summ_id = FakeScalarSummary('id', i)
      summ_sq = FakeScalarSummary('sq', i * i)
      writer.add_summary(summ_id, i * 5)
      writer.add_summary(summ_sq, i * 5)
    writer.flush()

    # Verify that we can load those events properly
    acc = ea.EventAccumulator(directory)
    acc.Reload()
    self.assertTagsEqual(acc.Tags(), {
        ea.IMAGES: [],
        ea.SCALARS: ['id', 'sq'],
        ea.HISTOGRAMS: [],
        ea.COMPRESSED_HISTOGRAMS: [],
        ea.GRAPH: True,
        ea.RUN_METADATA: ['test run']
    })
    AssertScalarsLoaded(30)

    # Write a few more events to test incremental reloading
    for i in xrange(30, 40):
      summ_id = FakeScalarSummary('id', i)
      summ_sq = FakeScalarSummary('sq', i * i)
      writer.add_summary(summ_id, i * 5)
      writer.add_summary(summ_sq, i * 5)
    writer.flush()

    # Verify we can now see all of the data
    acc.Reload()
    AssertScalarsLoaded(40)
    self.assertProtoEquals(graph_def, acc.Graph())


if __name__ == '__main__':
  # When executed directly as a script, hand control to TensorFlow's test
  # runner.
  tf.test.main()