aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tools/accuracy
diff options
context:
space:
mode:
authorGravatar Shashi Shekhar <shashishekhar@google.com>2018-08-30 12:20:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 12:29:52 -0700
commit579315f63b7fd6ea03ee63d81740286cdff2bb6b (patch)
treee55e5147c98375fb4b7d705e9538a58b4c8d6f2e /tensorflow/contrib/lite/tools/accuracy
parent2fbbbb94b0c683f9eed59c3b560d7a3d3f8fc94e (diff)
Add parallel evaluation and blacklist support.
PiperOrigin-RevId: 210957416
Diffstat (limited to 'tensorflow/contrib/lite/tools/accuracy')
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD2
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md8
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt1762
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc53
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc228
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h29
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc25
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h11
8 files changed, 2036 insertions, 82 deletions
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
index db4b688a45..14eb55b036 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -85,6 +85,7 @@ cc_library(
],
"//conditions:default": [
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
],
},
),
@@ -137,6 +138,7 @@ cc_library(
],
"//conditions:default": [
"//tensorflow/core:tensorflow",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
index 9b3b99451d..362ea3ac34 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
@@ -30,9 +30,17 @@ The binary takes the following parameters:
This is the path to the output file. The output is a CSV file that has top-10 accuracies in each row. Each line of output file is the cumulative accuracy after processing images in a sorted order. So first line is accuracy after processing the first image, second line is accuracy after procesing first two images. The last line of the file is accuracy after processing the entire validation set.
and the following optional parameters:
+
+* `blacklist_file_path`: `string` \
+ Path to blacklist file. This file contains the indices of images that are blacklisted for evaluation. 1762 images are blacklisted in ILSVRC dataset. For details please refer to readme.txt of ILSVRC2014 devkit.
+
* `num_images`: `int` (default=0) \
The number of images to process, if 0, all images in the directory are processed otherwise only num_images will be processed.
+* `num_threads`: `int` (default=4) \
+ The number of threads to use for evaluation.
+
+
## Downloading ILSVRC
In order to use this tool to run evaluation on the full 50K ImageNet dataset,
download the data set from http://image-net.org/request.
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
new file mode 100644
index 0000000000..b2f00e034e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
@@ -0,0 +1,1762 @@
+36
+50
+56
+103
+127
+195
+199
+226
+230
+235
+251
+254
+288
+397
+485
+543
+556
+601
+605
+652
+653
+663
+666
+697
+699
+705
+745
+774
+815
+816
+845
+848
+951
+977
+1006
+1008
+1018
+1056
+1066
+1079
+1102
+1128
+1133
+1188
+1193
+1194
+1266
+1271
+1372
+1382
+1405
+1426
+1430
+1441
+1477
+1502
+1518
+1606
+1621
+1642
+1658
+1716
+1722
+1734
+1750
+1807
+1880
+1882
+1936
+1951
+1970
+1977
+1983
+2086
+2112
+2146
+2152
+2217
+2304
+2321
+2404
+2526
+2554
+2563
+2647
+2675
+2732
+2733
+2827
+2839
+2854
+2865
+2872
+2880
+2886
+2893
+2915
+2973
+2993
+3019
+3020
+3044
+3047
+3049
+3117
+3167
+3197
+3201
+3282
+3311
+3315
+3344
+3345
+3378
+3425
+3477
+3497
+3514
+3525
+3531
+3587
+3637
+3650
+3657
+3686
+3720
+3732
+3798
+3802
+3823
+3847
+3971
+4007
+4059
+4072
+4087
+4099
+4124
+4126
+4156
+4195
+4197
+4241
+4275
+4321
+4333
+4352
+4356
+4368
+4377
+4428
+4440
+4497
+4509
+4513
+4526
+4528
+4565
+4570
+4596
+4633
+4677
+4696
+4743
+4759
+4778
+4835
+4976
+5032
+5058
+5061
+5066
+5140
+5145
+5177
+5197
+5219
+5226
+5228
+5240
+5289
+5292
+5385
+5433
+5445
+5448
+5465
+5488
+5549
+5553
+5609
+5638
+5666
+5683
+5711
+5729
+5760
+5793
+5819
+5837
+5855
+5858
+5961
+5966
+6048
+6197
+6199
+6201
+6206
+6215
+6220
+6264
+6278
+6280
+6305
+6388
+6411
+6466
+6490
+6509
+6523
+6529
+6625
+6754
+6818
+6886
+6890
+6893
+6902
+6912
+6942
+7067
+7141
+7144
+7214
+7217
+7278
+7312
+7320
+7329
+7342
+7345
+7369
+7408
+7428
+7463
+7556
+7557
+7582
+7613
+7621
+7624
+7647
+7671
+7679
+7734
+7736
+7747
+7750
+7777
+7851
+7854
+7883
+7889
+7902
+7985
+7999
+8070
+8087
+8096
+8100
+8128
+8180
+8195
+8367
+8377
+8465
+8497
+8508
+8528
+8538
+8581
+8657
+8692
+8742
+8784
+8839
+8861
+8912
+8970
+8982
+8987
+9103
+9155
+9180
+9248
+9284
+9300
+9357
+9382
+9414
+9450
+9463
+9493
+9522
+9543
+9563
+9630
+9643
+9653
+9693
+9747
+9787
+9847
+9851
+9892
+9913
+9929
+9965
+10026
+10027
+10055
+10154
+10189
+10243
+10297
+10337
+10346
+10347
+10377
+10403
+10483
+10518
+10540
+10559
+10567
+10568
+10580
+10606
+10615
+10618
+10645
+10685
+10707
+10710
+10807
+10837
+10856
+10873
+10989
+11046
+11054
+11132
+11163
+11218
+11243
+11255
+11265
+11292
+11306
+11307
+11310
+11343
+11349
+11407
+11411
+11422
+11427
+11431
+11439
+11496
+11644
+11662
+11690
+11692
+11725
+11743
+11767
+11812
+11867
+11871
+11897
+11975
+12001
+12046
+12076
+12119
+12158
+12216
+12252
+12261
+12264
+12293
+12296
+12306
+12357
+12358
+12371
+12415
+12422
+12472
+12497
+12499
+12538
+12540
+12544
+12569
+12645
+12647
+12652
+12699
+12727
+12750
+12832
+12849
+12873
+12889
+12902
+12996
+13029
+13065
+13073
+13075
+13079
+13268
+13338
+13372
+13529
+13530
+13537
+13623
+13626
+13637
+13644
+13646
+13681
+13778
+13782
+13805
+13846
+13853
+13881
+13914
+13961
+13975
+13979
+14011
+14135
+14143
+14144
+14161
+14170
+14207
+14212
+14215
+14260
+14311
+14368
+14373
+14400
+14509
+14523
+14566
+14594
+14628
+14629
+14633
+14649
+14652
+14705
+14709
+14732
+14734
+14802
+14834
+14865
+14883
+14933
+14965
+15003
+15100
+15159
+15178
+15272
+15289
+15308
+15319
+15327
+15353
+15357
+15363
+15408
+15429
+15438
+15469
+15485
+15495
+15501
+15524
+15530
+15551
+15598
+15613
+15614
+15631
+15646
+15647
+15661
+15679
+15684
+15758
+15775
+15826
+15838
+15840
+15931
+15940
+15969
+15976
+16003
+16037
+16045
+16116
+16200
+16233
+16247
+16339
+16340
+16345
+16361
+16400
+16408
+16430
+16468
+16474
+16500
+16521
+16565
+16569
+16584
+16613
+16645
+16662
+16671
+16719
+16724
+16760
+16764
+16805
+16849
+16893
+16896
+16954
+16979
+17023
+17026
+17034
+17038
+17049
+17054
+17061
+17073
+17074
+17133
+17163
+17176
+17177
+17217
+17237
+17246
+17298
+17312
+17324
+17337
+17365
+17415
+17442
+17449
+17576
+17578
+17581
+17588
+17589
+17591
+17593
+17605
+17661
+17688
+17689
+17695
+17697
+17703
+17736
+17746
+17758
+17788
+17798
+17828
+17841
+17884
+17898
+17924
+17956
+17960
+18001
+18013
+18025
+18052
+18097
+18106
+18158
+18211
+18223
+18240
+18261
+18266
+18297
+18325
+18329
+18335
+18340
+18351
+18433
+18462
+18466
+18524
+18569
+18581
+18631
+18696
+18748
+18766
+18787
+18793
+18950
+18961
+19001
+19008
+19011
+19154
+19177
+19217
+19255
+19286
+19320
+19333
+19360
+19403
+19407
+19419
+19464
+19499
+19510
+19519
+19555
+19564
+19605
+19610
+19689
+19699
+19705
+19707
+19725
+19732
+19741
+19774
+19799
+19838
+19877
+19903
+19940
+19945
+19952
+19973
+19987
+20024
+20086
+20111
+20114
+20174
+20193
+20201
+20245
+20299
+20329
+20439
+20485
+20534
+20562
+20575
+20578
+20601
+20604
+20605
+20648
+20658
+20665
+20677
+20693
+20697
+20699
+20791
+20794
+20808
+20876
+20890
+20906
+20914
+20990
+21065
+21128
+21144
+21151
+21156
+21175
+21199
+21204
+21207
+21225
+21236
+21241
+21342
+21351
+21429
+21533
+21550
+21622
+21676
+21727
+21764
+21785
+21822
+21830
+21845
+21853
+21867
+21909
+21910
+21923
+21924
+21937
+21948
+21955
+21962
+22008
+22017
+22026
+22037
+22072
+22075
+22135
+22138
+22160
+22167
+22190
+22287
+22375
+22440
+22457
+22460
+22471
+22481
+22484
+22488
+22515
+22553
+22679
+22703
+22714
+22730
+22735
+22752
+22768
+22809
+22813
+22817
+22846
+22902
+22910
+22944
+22986
+23026
+23053
+23065
+23088
+23117
+23124
+23126
+23132
+23142
+23165
+23172
+23223
+23264
+23280
+23322
+23335
+23439
+23453
+23455
+23474
+23501
+23518
+23580
+23589
+23608
+23614
+23641
+23649
+23660
+23698
+23728
+23766
+23809
+23859
+23874
+23902
+23946
+24040
+24105
+24132
+24137
+24151
+24153
+24157
+24171
+24271
+24281
+24296
+24303
+24308
+24328
+24332
+24338
+24402
+24440
+24453
+24466
+24504
+24531
+24543
+24547
+24556
+24562
+24610
+24649
+24660
+24693
+24706
+24745
+24834
+24948
+24963
+25056
+25057
+25083
+25093
+25120
+25150
+25161
+25197
+25219
+25220
+25253
+25257
+25290
+25327
+25332
+25344
+25387
+25390
+25422
+25453
+25481
+25489
+25587
+25599
+25600
+25622
+25681
+25686
+25702
+25708
+25740
+25776
+25870
+25918
+25973
+25978
+25986
+25987
+26033
+26038
+26041
+26087
+26113
+26155
+26162
+26184
+26235
+26299
+26301
+26318
+26364
+26383
+26430
+26511
+26528
+26561
+26618
+26653
+26688
+26697
+26778
+26940
+26951
+27023
+27029
+27037
+27046
+27051
+27118
+27244
+27252
+27258
+27272
+27283
+27303
+27381
+27392
+27403
+27422
+27437
+27440
+27476
+27493
+27494
+27501
+27506
+27550
+27559
+27571
+27581
+27596
+27604
+27612
+27665
+27687
+27701
+27711
+27732
+27759
+27766
+27772
+27797
+27813
+27854
+27864
+27865
+27879
+27894
+27907
+27958
+27963
+27969
+28003
+28027
+28032
+28051
+28058
+28079
+28093
+28120
+28132
+28194
+28227
+28324
+28328
+28331
+28360
+28373
+28419
+28431
+28436
+28451
+28467
+28471
+28527
+28541
+28588
+28640
+28649
+28662
+28670
+28678
+28722
+28768
+28780
+28835
+28863
+28879
+28885
+28928
+28948
+28954
+28963
+28969
+29020
+29065
+29077
+29105
+29117
+29143
+29166
+29172
+29299
+29302
+29342
+29357
+29378
+29410
+29411
+29414
+29415
+29447
+29473
+29488
+29499
+29505
+29533
+29537
+29601
+29637
+29650
+29667
+29671
+29681
+29686
+29708
+29721
+29749
+29755
+29771
+29853
+29886
+29894
+29919
+29928
+29990
+30008
+30064
+30067
+30107
+30150
+30160
+30164
+30186
+30195
+30219
+30243
+30282
+30314
+30324
+30389
+30418
+30497
+30550
+30592
+30615
+30624
+30640
+30650
+30695
+30720
+30741
+30750
+30751
+30767
+30830
+30856
+30885
+30901
+30907
+30953
+30985
+31005
+31027
+31034
+31045
+31057
+31071
+31109
+31119
+31227
+31230
+31250
+31303
+31320
+31371
+31401
+31440
+31447
+31464
+31478
+31487
+31494
+31525
+31553
+31554
+31558
+31572
+31588
+31639
+31641
+31683
+31698
+31704
+31708
+31717
+31722
+31781
+31786
+31788
+31791
+31803
+31850
+31853
+31862
+31886
+31901
+31944
+32020
+32048
+32052
+32073
+32094
+32116
+32147
+32180
+32212
+32218
+32256
+32270
+32305
+32411
+32414
+32430
+32465
+32484
+32534
+32584
+32589
+32608
+32612
+32613
+32615
+32641
+32674
+32697
+32708
+32757
+32763
+32796
+32824
+32861
+32877
+32944
+32945
+32946
+32984
+33004
+33012
+33029
+33050
+33090
+33096
+33097
+33124
+33139
+33161
+33170
+33173
+33179
+33191
+33293
+33367
+33370
+33371
+33373
+33399
+33415
+33436
+33440
+33443
+33488
+33551
+33563
+33564
+33629
+33643
+33664
+33685
+33696
+33714
+33722
+33728
+33764
+33809
+33868
+33883
+33913
+33942
+33956
+33994
+34081
+34089
+34091
+34098
+34178
+34207
+34269
+34287
+34348
+34392
+34445
+34447
+34455
+34529
+34579
+34591
+34643
+34659
+34692
+34729
+34758
+34836
+34857
+34862
+34883
+34930
+34942
+34957
+34963
+35003
+35089
+35180
+35187
+35209
+35220
+35239
+35247
+35253
+35263
+35380
+35393
+35394
+35408
+35452
+35485
+35486
+35557
+35578
+35639
+35663
+35688
+35746
+35832
+35862
+35890
+35903
+35917
+35929
+35946
+35984
+36060
+36084
+36090
+36124
+36135
+36151
+36197
+36249
+36269
+36303
+36364
+36377
+36398
+36402
+36418
+36421
+36435
+36499
+36511
+36521
+36544
+36556
+36601
+36627
+36640
+36660
+36673
+36676
+36787
+36790
+36797
+36821
+36840
+36901
+36921
+36934
+37006
+37041
+37051
+37112
+37160
+37167
+37213
+37231
+37242
+37274
+37313
+37332
+37391
+37416
+37522
+37594
+37621
+37664
+37699
+37731
+37915
+37968
+38030
+38070
+38117
+38128
+38135
+38172
+38184
+38224
+38277
+38295
+38311
+38428
+38464
+38529
+38549
+38599
+38623
+38673
+38681
+38713
+38722
+38726
+38762
+38867
+38872
+38944
+38947
+39015
+39023
+39028
+39043
+39068
+39080
+39097
+39118
+39171
+39197
+39236
+39254
+39271
+39277
+39280
+39336
+39338
+39340
+39341
+39358
+39364
+39497
+39503
+39537
+39541
+39559
+39560
+39562
+39596
+39600
+39613
+39623
+39656
+39670
+39781
+39810
+39832
+39861
+39875
+39892
+39918
+39919
+40008
+40016
+40082
+40091
+40095
+40164
+40213
+40234
+40274
+40279
+40324
+40332
+40341
+40349
+40365
+40438
+40446
+40482
+40501
+40510
+40516
+40541
+40544
+40545
+40574
+40617
+40659
+40668
+40742
+40754
+40758
+40764
+40765
+40795
+40858
+40901
+40985
+40986
+41080
+41112
+41121
+41136
+41196
+41199
+41219
+41233
+41246
+41278
+41376
+41401
+41409
+41434
+41470
+41492
+41502
+41517
+41571
+41572
+41608
+41648
+41699
+41773
+41779
+41801
+41837
+41843
+41849
+41855
+41873
+41881
+41901
+41924
+41926
+41935
+41962
+42008
+42062
+42069
+42072
+42094
+42097
+42104
+42112
+42117
+42137
+42147
+42170
+42185
+42224
+42237
+42250
+42254
+42257
+42276
+42282
+42298
+42321
+42351
+42372
+42378
+42420
+42446
+42453
+42466
+42470
+42502
+42514
+42518
+42527
+42662
+42721
+42727
+42743
+42794
+42840
+42843
+42871
+42872
+42897
+42950
+42956
+42967
+42969
+42975
+42995
+43005
+43008
+43046
+43052
+43091
+43103
+43124
+43198
+43225
+43228
+43385
+43394
+43402
+43405
+43408
+43423
+43503
+43529
+43557
+43647
+43656
+43704
+43706
+43714
+43745
+43748
+43759
+43812
+43927
+43950
+43997
+43998
+44016
+44018
+44025
+44060
+44066
+44099
+44128
+44149
+44150
+44169
+44184
+44198
+44254
+44272
+44293
+44310
+44352
+44389
+44399
+44400
+44442
+44451
+44470
+44474
+44522
+44569
+44590
+44713
+44738
+44787
+44823
+44829
+44845
+44895
+44918
+44975
+45024
+45121
+45148
+45154
+45179
+45208
+45210
+45215
+45218
+45220
+45235
+45265
+45282
+45283
+45285
+45286
+45303
+45351
+45359
+45396
+45407
+45414
+45472
+45519
+45522
+45564
+45621
+45641
+45660
+45678
+45695
+45696
+45710
+45780
+45800
+45823
+45828
+45862
+45947
+45964
+46001
+46050
+46084
+46113
+46132
+46146
+46198
+46221
+46234
+46236
+46256
+46272
+46298
+46325
+46337
+46347
+46374
+46386
+46388
+46437
+46491
+46560
+46561
+46589
+46600
+46656
+46660
+46664
+46673
+46690
+46700
+46808
+46809
+46828
+46918
+46963
+46979
+46984
+47005
+47088
+47097
+47100
+47143
+47147
+47261
+47320
+47369
+47450
+47503
+47533
+47538
+47576
+47601
+47608
+47618
+47621
+47624
+47659
+47681
+47698
+47708
+47745
+47817
+47826
+47879
+47883
+47917
+47937
+47957
+48000
+48023
+48076
+48099
+48130
+48133
+48281
+48298
+48321
+48349
+48351
+48353
+48358
+48371
+48426
+48455
+48522
+48526
+48544
+48573
+48606
+48609
+48646
+48667
+48699
+48701
+48740
+48773
+48777
+48785
+48847
+48886
+48940
+48986
+49029
+49054
+49100
+49121
+49137
+49157
+49191
+49222
+49291
+49315
+49347
+49374
+49376
+49381
+49407
+49427
+49481
+49497
+49624
+49785
+49791
+49835
+49875
+49877
+49981
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
index f361341f7c..2a8a2b9b59 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
@@ -52,18 +52,22 @@ class ResultsWriter : public ImagenetModelEvaluator::Observer {
explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
: writer_(std::move(writer)) {}
- void OnEvaluationStart(int total_number_of_images) override {}
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {}
void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
private:
- std::unique_ptr<CSVWriter> writer_;
+ std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
+ mutex mu_;
};
void ResultsWriter::OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ mutex_lock lock(mu_);
TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
writer_->Flush();
}
@@ -71,33 +75,40 @@ void ResultsWriter::OnSingleImageEvaluationComplete(
// Logs results to standard output with `kLogDelayUs` microseconds.
class ResultsLogger : public ImagenetModelEvaluator::Observer {
public:
- void OnEvaluationStart(int total_number_of_images) override;
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override;
void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) override;
private:
- int total_num_images_ = 0;
- uint64 last_logged_time_us_ = 0;
+ uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
+ int total_num_images_ GUARDED_BY(mu_);
static constexpr int kLogDelayUs = 500 * 1000;
+ mutex mu_;
};
-void ResultsLogger::OnEvaluationStart(int total_number_of_images) {
- total_num_images_ = total_number_of_images;
- LOG(ERROR) << "Starting model evaluation: " << total_num_images_;
+void ResultsLogger::OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
+ int total_num_images = 0;
+ for (const auto& kv : shard_id_image_count_map) {
+ total_num_images += kv.second;
+ }
+ LOG(ERROR) << "Starting model evaluation: " << total_num_images;
+ mutex_lock lock(mu_);
+ total_num_images_ = total_num_images;
}
void ResultsLogger::OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats, const string& image) {
- int num_evaluated = stats.number_of_images;
-
- double current_percent = num_evaluated * 100.0 / total_num_images_;
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
auto now_us = Env::Default()->NowMicros();
-
+ int num_evaluated = stats.number_of_images;
+ mutex_lock lock(mu_);
if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
last_logged_time_us_ = now_us;
-
+ double current_percent = num_evaluated * 100.0 / total_num_images_;
LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_
<< " images, " << std::setprecision(2) << std::fixed
<< current_percent << "%";
@@ -108,15 +119,20 @@ int Main(int argc, char* argv[]) {
// TODO(shashishekhar): Make this binary configurable and model
// agnostic.
string output_file_path;
+ int num_threads = 4;
std::vector<Flag> flag_list = {
Flag("output_file_path", &output_file_path, "Path to output file."),
+ Flag("num_threads", &num_threads, "Number of threads."),
};
Flags::Parse(&argc, argv, flag_list);
std::unique_ptr<ImagenetModelEvaluator> evaluator;
CHECK(!output_file_path.empty()) << "Invalid output file path.";
- TF_CHECK_OK(ImagenetModelEvaluator::Create(argc, argv, &evaluator));
+ CHECK(num_threads > 0) << "Invalid number of threads.";
+
+ TF_CHECK_OK(
+ ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
std::ofstream output_stream(output_file_path, std::ios::out);
CHECK(output_stream) << "Unable to open output file path: '"
@@ -136,6 +152,7 @@ int Main(int argc, char* argv[]) {
ResultsLogger logger;
evaluator->AddObserver(&results_writer);
evaluator->AddObserver(&logger);
+ LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
TF_CHECK_OK(evaluator->EvaluateModel());
return 0;
}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
index a88a4a0fce..e0e5dcb264 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -29,7 +29,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -57,6 +60,21 @@ std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
return result;
}
+template <typename T>
+std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
+ CHECK_GT(n, 0);
+ std::vector<std::vector<T>> vecs(n);
+ int input_index = 0;
+ int vec_index = 0;
+ while (input_index < v.size()) {
+ vecs[vec_index].push_back(v[input_index]);
+ vec_index = (vec_index + 1) % n;
+ input_index++;
+ }
+ CHECK_EQ(vecs.size(), n);
+ return vecs;
+}
+
// File pattern for imagenet files.
const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
@@ -65,8 +83,36 @@ const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
namespace tensorflow {
namespace metrics {
+class CompositeObserver : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit CompositeObserver(const std::vector<Observer*>& observers)
+ : observers_(observers) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnEvaluationStart(shard_id_image_count_map);
+ }
+ }
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
+ }
+ }
+
+ private:
+ const std::vector<ImagenetModelEvaluator::Observer*>& observers_
+ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
/*static*/ Status ImagenetModelEvaluator::Create(
- int argc, char* argv[],
+ int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
Params params;
const std::vector<Flag> flag_list = {
@@ -82,8 +128,12 @@ namespace metrics {
Flag("num_images", &params.number_of_images,
"Number of examples to evaluate, pass 0 for all "
"examples. Default: 100"),
- tensorflow::Flag("model_file", &params.model_file_path,
- "Path to test tflite model file."),
+ Flag("blacklist_file_path", &params.blacklist_file_path,
+ "Path to blacklist file (optional)."
+ "Path to blacklist file where each line is a single integer that is "
+ "equal to number of blacklisted image."),
+ Flag("model_file", &params.model_file_path,
+ "Path to test tflite model file."),
};
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
if (!parse_result)
@@ -100,6 +150,12 @@ namespace metrics {
Env::Default()->FileExists(params.model_output_labels_path),
"Invalid model output labels path.");
+ if (!params.blacklist_file_path.empty()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.blacklist_file_path),
+ "Invalid blacklist path.");
+ }
+
if (params.number_of_images < 0) {
return errors::InvalidArgument("Invalid: num_examples");
}
@@ -109,28 +165,30 @@ namespace metrics {
utils::GetTFliteModelInfo(params.model_file_path, &model_info),
"Invalid TFLite model.");
- *model_evaluator =
- absl::make_unique<ImagenetModelEvaluator>(model_info, params);
+ *model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
+ model_info, params, num_threads);
return Status::OK();
}
-Status ImagenetModelEvaluator::EvaluateModel() {
- if (model_info_.input_shapes.size() != 1) {
- return errors::InvalidArgument("Invalid input shape");
- }
-
- const TensorShape& input_shape = model_info_.input_shapes[0];
- // Input should be of the shape {1, height, width, 3}
- if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
- return errors::InvalidArgument("Invalid input shape for the model.");
- }
-
+struct ImageLabel {
+ string image;
+ string label;
+};
+
+Status EvaluateModelForShard(const uint64_t shard_id,
+ const std::vector<ImageLabel>& image_labels,
+ const std::vector<string>& model_labels,
+ const utils::ModelInfo& model_info,
+ const ImagenetModelEvaluator::Params& params,
+ ImagenetModelEvaluator::Observer* observer,
+ ImagenetTopKAccuracy* eval) {
+ const TensorShape& input_shape = model_info.input_shapes[0];
const int image_height = input_shape.dim_size(1);
const int image_width = input_shape.dim_size(2);
- const bool is_quantized = (model_info_.input_types[0] == DT_UINT8);
+ const bool is_quantized = (model_info.input_types[0] == DT_UINT8);
RunTFLiteModelStage::Params tfl_model_params;
- tfl_model_params.model_file_path = params_.model_file_path;
+ tfl_model_params.model_file_path = params.model_file_path;
if (is_quantized) {
tfl_model_params.input_type = {DT_UINT8};
tfl_model_params.output_type = {DT_UINT8};
@@ -144,29 +202,77 @@ Status ImagenetModelEvaluator::EvaluateModel() {
InceptionPreprocessingStage inc(image_height, image_width, is_quantized);
RunTFLiteModelStage tfl_model_stage(tfl_model_params);
EvalPipelineBuilder builder;
- std::vector<string> model_labels;
- TF_RETURN_IF_ERROR(
- utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
- if (model_labels.size() != 1001) {
- return errors::InvalidArgument("Invalid number of labels: ",
- model_labels.size());
- }
- ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
std::unique_ptr<EvalPipeline> eval_pipeline;
auto build_status = builder.WithInputStage(&reader)
.WithPreprocessingStage(&inc)
.WithRunModelStage(&tfl_model_stage)
- .WithAccuracyEval(&eval)
+ .WithAccuracyEval(eval)
.WithInput("input_file", DT_STRING)
.Build(root, &eval_pipeline);
TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
"Failure while building eval pipeline.");
-
std::unique_ptr<Session> session(NewSession(SessionOptions()));
TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
+
+ for (const auto& image_label : image_labels) {
+ TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
+ CreateStringTensor(image_label.label)));
+ observer->OnSingleImageEvaluationComplete(
+ shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
+ }
+ return Status::OK();
+}
+
+Status FilterBlackListedImages(const string& blacklist_file_path,
+ std::vector<ImageLabel>* image_labels) {
+ if (!blacklist_file_path.empty()) {
+ std::vector<string> lines;
+ TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
+ std::vector<int> blacklist_ids;
+ blacklist_ids.reserve(lines.size());
+ // Populate blacklist_ids with indices of images.
+ std::transform(lines.begin(), lines.end(),
+ std::back_insert_iterator(blacklist_ids),
+ [](const string& val) { return std::stoi(val) - 1; });
+
+ std::vector<ImageLabel> filtered_images;
+ std::sort(blacklist_ids.begin(), blacklist_ids.end());
+ const size_t size_post_filtering =
+ image_labels->size() - blacklist_ids.size();
+ filtered_images.reserve(size_post_filtering);
+ int blacklist_index = 0;
+ for (int image_index = 0; image_index < image_labels->size();
+ image_index++) {
+ if (blacklist_index < blacklist_ids.size() &&
+ blacklist_ids[blacklist_index] == image_index) {
+ blacklist_index++;
+ continue;
+ }
+ filtered_images.push_back((*image_labels)[image_index]);
+ }
+
+ if (filtered_images.size() != size_post_filtering) {
+ return errors::Internal("Invalid number of filtered images");
+ }
+ *image_labels = filtered_images;
+ }
+ return Status::OK();
+}
+
+Status ImagenetModelEvaluator::EvaluateModel() const {
+ if (model_info_.input_shapes.size() != 1) {
+ return errors::InvalidArgument("Invalid input shape");
+ }
+
+ const TensorShape& input_shape = model_info_.input_shapes[0];
+ // Input should be of the shape {1, height, width, 3}
+ if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
+ return errors::InvalidArgument("Invalid input shape for the model.");
+ }
+
string data_path =
StripTrailingSlashes(params_.ground_truth_images_path) + "/";
@@ -174,31 +280,69 @@ Status ImagenetModelEvaluator::EvaluateModel() {
std::vector<string> image_files;
TF_CHECK_OK(
Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
- std::vector<string> image_labels;
- TF_CHECK_OK(
- utils::ReadFileLines(params_.ground_truth_labels_path, &image_labels));
- CHECK_EQ(image_files.size(), image_labels.size());
+ std::vector<string> ground_truth_image_labels;
+ TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
+ &ground_truth_image_labels));
+ CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
// Process files in filename sorted order.
std::sort(image_files.begin(), image_files.end());
+
+ std::vector<ImageLabel> image_labels;
+ image_labels.reserve(image_files.size());
+ for (int i = 0; i < image_files.size(); i++) {
+ image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
+ }
+
+ // Filter any blacklisted images.
+ TF_CHECK_OK(
+ FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
+
if (params_.number_of_images > 0) {
- image_files = GetFirstN(image_files, params_.number_of_images);
image_labels = GetFirstN(image_labels, params_.number_of_images);
}
- for (Observer* observer : observers_) {
- observer->OnEvaluationStart(image_files.size());
+ std::vector<string> model_labels;
+ TF_RETURN_IF_ERROR(
+ utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
+ if (model_labels.size() != 1001) {
+ return errors::InvalidArgument("Invalid number of labels: ",
+ model_labels.size());
}
- for (int i = 0; i < image_files.size(); i++) {
- TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_files[i]),
- CreateStringTensor(image_labels[i])));
- auto stats = eval.GetTopKAccuracySoFar();
+ ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
- for (Observer* observer : observers_) {
- observer->OnSingleImageEvaluationComplete(stats, image_files[i]);
- }
+ auto img_labels = Split(image_labels, num_threads_);
+
+ BlockingCounter counter(num_threads_);
+
+ CompositeObserver observer(observers_);
+
+ ::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
+ num_threads_);
+ std::unordered_map<uint64_t, int> shard_id_image_count_map;
+ std::vector<std::function<void()>> thread_funcs;
+ thread_funcs.reserve(num_threads_);
+ for (int i = 0; i < num_threads_; i++) {
+ const auto& image_label = img_labels[i];
+ const uint64_t shard_id = i + 1;
+ shard_id_image_count_map[shard_id] = image_label.size();
+ auto func = [&]() {
+ TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
+ model_info_, params_, &observer,
+ &eval));
+ counter.DecrementCount();
+ };
+ thread_funcs.push_back(func);
}
+
+ observer.OnEvaluationStart(shard_id_image_count_map);
+ for (const auto& func : thread_funcs) {
+ pool.Schedule(func);
+ }
+
+ counter.Wait();
+
return Status::OK();
}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
index 5f42b2a50e..97e4232b35 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -56,6 +56,13 @@ class ImagenetModelEvaluator {
// Path to the model file.
string model_file_path;
+ // Path to black list file. 1762 images were blacklisted from
+ // original ILSVRC dataset. This black list file is present in
+ // ILSVRC2014 devkit. Please refer to readme.txt of the ILSVRC2014
+ // devkit for details.
+ // This file is a list of image indices in a sorted order.
+ string blacklist_file_path;
+
// The maximum number of images to calculate accuracy.
// 0 means all images, a positive number means only the specified
// number of images.
@@ -66,6 +73,7 @@ class ImagenetModelEvaluator {
};
// An evaluation observer.
+ // Observers can be called from multiple threads and need to be thread safe.
class Observer {
public:
Observer() = default;
@@ -76,38 +84,41 @@ class ImagenetModelEvaluator {
Observer& operator=(const Observer&&) = delete;
// Called on start of evaluation.
- virtual void OnEvaluationStart(int total_number_of_images) = 0;
+ // `shard_id_image_count_map` map from shard id to image count.
+ virtual void OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) = 0;
// Called when evaluation was complete for `image`.
virtual void OnSingleImageEvaluationComplete(
- const ImagenetTopKAccuracy::AccuracyStats& stats,
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
const string& image) = 0;
virtual ~Observer() = default;
};
ImagenetModelEvaluator(const utils::ModelInfo& model_info,
- const Params& params)
- : model_info_(model_info), params_(params) {}
+ const Params& params, const int num_threads)
+ : model_info_(model_info), params_(params), num_threads_(num_threads) {}
// Factory method to create the evaluator by parsing command line arguments.
- static Status Create(int argc, char* argv[],
+ static Status Create(int argc, char* argv[], int num_threads,
std::unique_ptr<ImagenetModelEvaluator>* evaluator);
// Adds an observer that can observe evaluation events..
void AddObserver(Observer* observer) { observers_.push_back(observer); }
- const Params& params() { return params_; }
+ const Params& params() const { return params_; }
// Evaluates the provided model over the dataset.
- Status EvaluateModel();
+ Status EvaluateModel() const;
private:
- std::vector<Observer*> observers_;
const utils::ModelInfo model_info_;
const Params params_;
+ const int num_threads_;
+ std::vector<Observer*> observers_;
};
} // namespace metrics
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
index d46075d234..c75baa82b1 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
@@ -77,26 +77,33 @@ Status ImagenetTopKAccuracy::ComputeEval(
CHECK_EQ(kNumCategories, probabilities.size());
std::vector<int> topK = GetTopK(probabilities, k_);
int ground_truth_index = GroundTruthIndex(ground_truth_label);
- for (size_t i = 0; i < topK.size(); ++i) {
- if (ground_truth_index == topK[i]) {
- for (size_t j = i; j < topK.size(); j++) {
- accuracy_counts_[j] += 1;
- }
- break;
- }
- }
- num_samples_++;
+ UpdateSamples(topK, ground_truth_index);
return Status::OK();
}
const ImagenetTopKAccuracy::AccuracyStats
ImagenetTopKAccuracy::GetTopKAccuracySoFar() const {
+ mutex_lock lock(mu_);
AccuracyStats stats;
stats.number_of_images = num_samples_;
stats.topk_counts = accuracy_counts_;
return stats;
}
+void ImagenetTopKAccuracy::UpdateSamples(const std::vector<int>& counts,
+ int ground_truth_index) {
+ mutex_lock lock(mu_);
+ for (size_t i = 0; i < counts.size(); ++i) {
+ if (ground_truth_index == counts[i]) {
+ for (size_t j = i; j < counts.size(); j++) {
+ accuracy_counts_[j] += 1;
+ }
+ break;
+ }
+ }
+ num_samples_++;
+}
+
int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const {
auto index = std::find(ground_truth_labels_.cbegin(),
ground_truth_labels_.cend(), label);
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
index 5a575ff244..cad646a30c 100644
--- a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace metrics {
@@ -69,12 +70,14 @@ class ImagenetTopKAccuracy : public AccuracyEval {
private:
int GroundTruthIndex(const string& label) const;
- std::vector<string> ground_truth_labels_;
+ void UpdateSamples(const std::vector<int>& counts, int ground_truth_index);
+ const std::vector<string> ground_truth_labels_;
const int k_;
- std::vector<int> accuracy_counts_;
- int num_samples_;
+ std::vector<int> accuracy_counts_ GUARDED_BY(mu_);
+ int num_samples_ GUARDED_BY(mu_);
+ mutable mutex mu_;
};
} // namespace metrics
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_