aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/generate_examples.py
blob: 6204471e52a5f0fee3d46388533162be52011dc1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Generate a series of TensorFlow graphs that become tflite test cases.

Usage:

generate_examples <output directory> zipped

bazel run //tensorflow/contrib/lite/testing:generate_examples
    third_party/tensorflow/contrib/lite/testing/generated_examples zipped
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import itertools
import os
import re
import sys
import tempfile
import traceback
import zipfile
import numpy as np
from six import StringIO

# TODO(aselle): Disable GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# pylint: disable=g-import-not-at-top
import tensorflow as tf
from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
from tensorflow.contrib.lite.testing import generate_examples_report as report_lib
from tensorflow.python.framework import graph_util as tf_graph_util

# Command-line interface of the generator. Two positional arguments (output
# directory and the legacy `type` marker) plus optional flags.
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
                    help="Directory where the outputs will go.")
# TODO(ahentz): remove this flag
parser.add_argument("type", help="zipped")
parser.add_argument("--zip_to_output",
                    type=str,
                    help="Particular zip to output.",
                    required=False)
parser.add_argument("--toco",
                    type=str,
                    help="Path to toco tool.",
                    required=True)
parser.add_argument(
    "--known_bugs_are_errors",
    action="store_true",
    help=("If a particular model is affected by a known bug,"
          " count it as a toco error."))
# When this flag is set, toco errors are only reported; without it the
# generator raises once errors have been counted.
parser.add_argument(
    "--ignore_toco_errors",
    action="store_true",
    help="Do not raise an exception if toco errors are encountered.")
parser.add_argument(
    "--save_graphdefs",
    action="store_true",
    help="Include intermediate graphdefs in the output zip files.")


# Seed passed to `np.random.seed` before each example is built, so that the
# randomly generated tensor data is reproducible from run to run.
RANDOM_SEED = 342
# Input-channel count of the convolution filter built by the
# control-dependency tests.
TEST_INPUT_DEPTH = 3


# A map from regular expression to bug number. Any test failure with label
# matching the expression will be considered due to the corresponding bug.
# Patterns are applied with `re.search` against the example label (zip path
# plus "key=value" parameter pairs); a matching toco failure is not counted
# as an error unless --known_bugs_are_errors is given.
KNOWN_BUGS = {
    # TOCO doesn't support scalars as input.
    r"relu.*input_shape=\[\]": "67587484",
    r"sigmoid.*input_shape=\[\]": "67645668",
    # Concat doesn't work with a single input tensor
    r"concat.*num_tensors=1": "67378344",
    # Transposition in MatMul is not supported.
    r"fully_connected.*transpose_.=True": "67586970",
    # Softmax graphs are too complex.
    r"softmax.*dim=0": "67749831",
    r"softmax.*input_shape=\[1,3,4,3\]": "67749831",
    # SpaceToDepth only supports float32.
    r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
    # BatchToSpaceND doesn't support cropping.
    r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634",
    # BatchToSpaceND only supports 4D tensors.
    r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
    # Div will use floordiv
    r"div.*int32": "72051395"
}


def toco_options(data_types,
                 input_arrays,
                 output_arrays,
                 shapes,
                 drop_control_dependency):
  """Assemble the command-line flags TOCO needs to convert a model.

  Args:
    data_types: TOCO data type names, one per input tensor. The first entry
      also decides the inference type.
    input_arrays: names of the input tensors.
    output_arrays: names of the output tensors.
    shapes: one iterable of dimensions per input tensor.
    drop_control_dependency: whether to append the flag that makes TOCO
      ignore control dependency nodes.

  Returns:
    A single string of space-prefixed `--flag=value` options.
  """
  # Dimensions are comma-separated within a shape, shapes are colon-separated.
  joined_shapes = ":".join(",".join(map(str, dims)) for dims in shapes)

  # TODO(ahentz): if we get multi-input quantization to work we need this
  # to change
  inference_type = (
      "QUANTIZED_UINT8" if data_types[0] == "QUANTIZED_UINT8" else "FLOAT")

  flags = [
      " --input_data_types=%s" % ",".join(data_types),
      " --inference_type=%s" % inference_type,
      " --input_format=TENSORFLOW_GRAPHDEF",
      " --output_format=TFLITE",
      " --input_arrays=%s" % ",".join(input_arrays),
      " --input_shapes=%s" % joined_shapes,
      " --output_arrays=%s" % ",".join(output_arrays),
  ]
  if drop_control_dependency:
    flags.append(" --drop_control_dependency")
  return "".join(flags)


def write_toco_options(filename,
                       data_types,
                       input_arrays,
                       output_arrays,
                       shapes,
                       drop_control_dependency=False):
  """Write the TOCO options for a model to a text file.

  Args:
    filename: path of the file that receives the option string.
    data_types: input and inference types used by TOCO.
    input_arrays: names of the input tensors.
    output_arrays: names of the output tensors.
    shapes: shapes of the input tensors.
    drop_control_dependency: whether to ignore control dependency nodes.
  """
  with open(filename, "w") as options_file:
    option_string = toco_options(data_types, input_arrays, output_arrays,
                                 shapes, drop_control_dependency)
    options_file.write(option_string)


def write_examples(fp, examples):
  """Serialize a list of examples in the csv-like TFLite example format.

  The output is a simple repeated pattern of "key,value..." lines. A proto
  would be preferable, but this format is shared with the Android team.

  Args:
    fp: File-like object to write to.
    examples: List of example dictionaries, each with the keys "inputs" and
      "outputs" mapping to lists of array-like tensors.
  """

  def emit_tensor(tensor):
    """Write the dtype, shape and values lines for one tensor."""
    fp.write("dtype,%s\n" % tensor.dtype)
    fp.write("shape,%s\n" % ",".join(str(dim) for dim in tensor.shape))
    # Nine digits after the decimal point keeps enough precision.
    formatted = ",".join(format(v, ".9f") for v in tensor.flatten())
    fp.write("values,%s\n" % formatted)

  fp.write("test_cases,%d\n" % len(examples))
  for example in examples:
    # Inputs are always written before outputs.
    for section in ("inputs", "outputs"):
      tensors = example[section]
      fp.write("%s,%d\n" % (section, len(tensors)))
      for tensor in tensors:
        emit_tensor(tensor)


def write_test_cases(fp, model_name, examples):
  """Serialize examples as a protocol-buffer-like test script.

  The text looks like a proto but is hand-written, because the Android team
  consumes this exact format.

  Args:
    fp: File-like object to write to.
    model_name: Path of the model file; only its base name is recorded.
    examples: List of example dictionaries, each with the keys "inputs" and
      "outputs" mapping to lists of array-like tensors.
  """

  def csv_values(tensor):
    """Return the flattened tensor values, 9 decimals each, comma-joined."""
    return ",".join(format(v, ".9f") for v in tensor.flatten())

  fp.write("load_model: %s\n" % os.path.basename(model_name))
  for example in examples:
    # First declare the input shapes ...
    fp.write("reshape {\n")
    for tensor in example["inputs"]:
      fp.write("  input: \"%s\"\n" % ",".join(str(d) for d in tensor.shape))
    fp.write("}\n")

    # ... then list the input values and the expected output values.
    fp.write("invoke {\n")
    for tensor in example["inputs"]:
      fp.write("  input: \"%s\"\n" % csv_values(tensor))
    for tensor in example["outputs"]:
      fp.write("  output: \"%s\"\n" % csv_values(tensor))
    fp.write("}\n")


# Maps each supported TensorFlow dtype to a pair of
# (numpy dtype used when generating data, TOCO data type name passed on the
# command line).
_TF_TYPE_INFO = {
    tf.float32: (np.float32, "FLOAT"),
    tf.float16: (np.float16, "FLOAT"),
    tf.int32: (np.int32, "INT32"),
    tf.uint8: (np.uint8, "QUANTIZED_UINT8"),
    tf.int64: (np.int64, "INT64"),
}


def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
  """Build random tensor data covering the range given by the bounds.

  Floating point data is drawn uniformly from [min_value, max_value);
  integer data is drawn from [min_value, max_value], both ends inclusive.

  Args:
    dtype: TensorFlow dtype (translated through `_TF_TYPE_INFO`) or an
      equivalent numpy dtype of the data to create.
    shape: shape of the generated array.
    min_value: lower bound of the generated values.
    max_value: upper bound of the generated values.

  Returns:
    A numpy array of the requested shape, cast to the numpy dtype.

  Raises:
    TypeError: if `dtype` is neither a supported float nor integer type.
  """

  if dtype in _TF_TYPE_INFO:
    dtype = _TF_TYPE_INFO[dtype][0]

  # The comparisons below are against TF dtypes even though `dtype` may by
  # now be a numpy type; this relies on TF dtype equality accepting numpy
  # types.
  if dtype in (tf.float32, tf.float16):
    value = (max_value-min_value)*np.random.random_sample(shape)+min_value
  elif dtype in (tf.int32, tf.uint8, tf.int64):
    # `randint` excludes its upper bound, so add one to keep `max_value`
    # inclusive (same draw as the deprecated `random_integers`).
    value = np.random.randint(min_value, max_value + 1, shape)
  else:
    # Fail with a clear message instead of an unbound-local error below.
    raise TypeError("Unsupported dtype for create_tensor_data: %r" % (dtype,))
  return value.astype(dtype)


def freeze_graph(session, outputs):
  """Return a frozen copy of the session's graph.

  Variables in the graph are converted to constants, keeping the nodes
  needed to compute the requested outputs.

  Args:
    session: TensorFlow session containing the graph.
    outputs: List of output tensors.

  Returns:
    The frozen graph_def.
  """
  graph_def = session.graph.as_graph_def()
  output_node_names = [tensor.op.name for tensor in outputs]
  return tf_graph_util.convert_variables_to_constants(
      session, graph_def, output_node_names)


def make_control_dep_tests(zip_path):
  """Generate the zip of test cases whose graphs use control dependencies."""

  shapes = [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]]
  test_parameters = [{"input_shape": shapes}]

  def build_graph(parameters):
    """Build a conv2d that is gated on an assertion over its input."""
    placeholder = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    kernel = tf.zeros((3, 3, TEST_INPUT_DEPTH, 8), tf.float32)
    # The assertion compares the input with itself minus one; it exists to
    # introduce a control dependency into the graph.
    guard = tf.assert_greater_equal(placeholder, placeholder - 1)
    with tf.control_dependencies([guard]):
      conv = tf.nn.conv2d(placeholder, kernel,
                          strides=(1, 1, 1, 1), padding="SAME")
    return [placeholder], [conv]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create random input data and evaluate the expected outputs."""
    data = create_tensor_data(tf.float32, parameters["input_shape"])
    expected = sess.run(outputs, feed_dict={inputs[0]: data})
    return [data], expected

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
                    drop_control_dependency=True)


def toco_convert(graph_def_str, input_tensors, output_tensors,
                 drop_control_dependency=False):
  """Convert a model's graph def into a tflite model.

  NOTE: this currently shells out to the toco binary, but we would like
  convert to Python API tooling in the future.

  Args:
    graph_def_str: Graph def proto in serialized string format.
    input_tensors: List of input tensor tuples `(name, shape, type)`
    output_tensors: List of output tensors (names)
    drop_control_dependency: whether to ignore control dependency nodes.

  Returns:
    output tflite model, log_txt from conversion
    or None, log_txt if it did not convert properly.
  """
  # Translate the TF dtype of every input into the TOCO data type name.
  data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
  opts = toco_options(
      data_types=data_types,
      input_arrays=[x[0] for x in input_tensors],
      shapes=[x[1] for x in input_tensors],
      output_arrays=output_tensors,
      drop_control_dependency=drop_control_dependency)

  # Three temporary files: the serialized graph fed to toco, the converted
  # model it writes, and the captured stdout/stderr of the run.
  with tempfile.NamedTemporaryFile() as graphdef_file, \
       tempfile.NamedTemporaryFile() as output_file, \
       tempfile.NamedTemporaryFile("w+") as stdout_file:
    graphdef_file.write(graph_def_str)
    # Make sure the bytes are on disk before the external tool reads them.
    graphdef_file.flush()

    # TODO(aselle): Switch this to subprocess at some point.
    # NOTE(review): `bin_path` is not defined in this function; presumably it
    # is a module-level global holding the toco binary path -- confirm.
    # The command is interpolated into a shell string without escaping.
    cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
           (bin_path, graphdef_file.name, output_file.name, opts,
            stdout_file.name))
    exit_code = os.system(cmd)
    # Separate the command from the status text so the log stays readable.
    log = (
        cmd + " exited with code %d" % exit_code + "\n------------------\n" +
        stdout_file.read())
    return (None if exit_code != 0 else output_file.read()), log


def make_zip_of_tests(zip_path,
                      test_parameters,
                      make_graph,
                      make_test_inputs,
                      drop_control_dependency=False):
  """Helper to make a zip file of a bunch of TensorFlow models.

  This does a cartesian product of each dictionary in test_parameters and
  calls make_graph() for each item in the cartesian product set.
  If the graph is built successfully, then make_test_inputs() is called to
  build expected input/output value pairs. The model is then converted to tflite
  with toco, and the examples are serialized with the tflite model into a zip
  file (a model file plus example files per converted item). A `report.html`
  and a `manifest.txt` are added to the archive at the end.

  Args:
    zip_path: Path of zip file to write
    test_parameters: List of dictionaries, each mapping a parameter name to
      the list of values to try for it.
      e.g. `[{"strides": [[1,3,3,1], [1,2,2,1]], "foo": [1.2, 1.3]}]`
    make_graph: function that takes current parameters and returns tuple
      `[input1, input2, ...], [output1, output2, ...]`
    make_test_inputs: function taking `curr_params`, `session`, `input_tensors`,
      `output_tensors` and returns tuple `(input_values, output_values)`.
    drop_control_dependency: whether to ignore control dependency nodes.
  Raises:
    RuntimeError: if there are toco errors that can't be ignored.
  """

  # TODO(aselle): Make this allow multiple inputs outputs.
  zip_manifest = []
  convert_report = []
  toco_errors = 0

  # NOTE(review): `FLAGS` is not defined in this function; presumably it is a
  # module-level global holding the parsed command-line arguments -- confirm.

  # The context manager guarantees the archive is finalized and closed even
  # if building one of the examples raises.
  with zipfile.PyZipFile(zip_path, "w") as archive:

    def build_example(label, param_dict_real):
      """Build the model with parameter values set in param_dict_real.

      Args:
        label: Label of the model (i.e. the filename in the zip).
        param_dict_real: Parameter dictionary (arguments to the factories
          make_graph and make_test_inputs)
      Returns:
        (tflite_model_binary, report) where tflite_model_binary is the
        serialized flatbuffer as a string and report is a dictionary with
        keys `toco_log` (log of toco conversion), `tf_log` (log of tf
        conversion), `toco` (a string of success status of the conversion),
        `tf` (a string success status of the conversion).
      """

      # Re-seed for every example so each one is reproducible on its own.
      np.random.seed(RANDOM_SEED)
      report = {"toco": report_lib.NOTRUN, "tf": report_lib.FAILED}

      # Build graph
      report["tf_log"] = ""
      report["toco_log"] = ""
      tf.reset_default_graph()

      with tf.device("/cpu:0"):
        try:
          inputs, outputs = make_graph(param_dict_real)
        except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                ValueError):
          # The parameter combination is not valid for TensorFlow itself.
          report["tf_log"] += traceback.format_exc()
          return None, report

      sess = tf.Session()
      try:
        try:
          baseline_inputs, baseline_outputs = (make_test_inputs(
              param_dict_real, sess, inputs, outputs))
        except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
                ValueError):
          report["tf_log"] += traceback.format_exc()
          return None, report
        report["toco"] = report_lib.FAILED
        report["tf"] = report_lib.SUCCESS

        # Convert graph to toco
        tflite_model_binary, toco_log = toco_convert(
            sess.graph_def.SerializeToString(),
            [(input_tensor.name.split(":")[0], input_tensor.get_shape(),
              input_tensor.dtype) for input_tensor in inputs],
            [out.name.split(":")[0]
             for out in outputs], drop_control_dependency)
        report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None
                          else report_lib.FAILED)
        report["toco_log"] = toco_log

        if FLAGS.save_graphdefs:
          archive.writestr(label + ".pb",
                           text_format.MessageToString(sess.graph_def),
                           zipfile.ZIP_DEFLATED)
      finally:
        # One session is created per example; release it before the next one.
        sess.close()

      if tflite_model_binary:
        archive.writestr(label + ".bin", tflite_model_binary,
                         zipfile.ZIP_DEFLATED)
        example = {"inputs": baseline_inputs, "outputs": baseline_outputs}

        example_fp = StringIO()
        write_examples(example_fp, [example])
        archive.writestr(label + ".inputs",
                         example_fp.getvalue(), zipfile.ZIP_DEFLATED)

        example_fp2 = StringIO()
        write_test_cases(example_fp2, label + ".bin", [example])
        archive.writestr(label + "_tests.txt",
                         example_fp2.getvalue(), zipfile.ZIP_DEFLATED)

        zip_manifest.append(label + "\n")

      return tflite_model_binary, report

    for parameters in test_parameters:
      keys = parameters.keys()
      for curr in itertools.product(*parameters.values()):
        # The label is the zip path without extension followed by the sorted
        # "key=value" pairs, with all spaces removed.
        label = zip_path.replace(".zip", "") + (",".join(
            "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", ""))
        if label[0] == "/":
          label = label[1:]
        param_dict = dict(zip(keys, curr))

        _, report = build_example(label, param_dict)

        if report["toco"] == report_lib.FAILED:
          ignore_error = False
          if not FLAGS.known_bugs_are_errors:
            for pattern, bug_number in KNOWN_BUGS.items():
              if re.search(pattern, label):
                print("Ignored TOCO error due to bug %s" % bug_number)
                ignore_error = True
          if not ignore_error:
            toco_errors += 1
            print("-----------------\ntoco error!\n%s\n-----------------\n" %
                  report["toco_log"])

        convert_report.append((param_dict, report))

    report_io = StringIO()
    report_lib.make_report_table(report_io, zip_path, convert_report)
    archive.writestr("report.html", report_io.getvalue())

    archive.writestr("manifest.txt", "".join(zip_manifest),
                     zipfile.ZIP_DEFLATED)

  # Log statistics of what succeeded
  total_conversions = len(convert_report)
  tf_success = sum(1 for x in convert_report
                   if x[1]["tf"] == report_lib.SUCCESS)
  toco_success = sum(1 for x in convert_report
                     if x[1]["toco"] == report_lib.SUCCESS)
  percent = 0
  if tf_success > 0:
    percent = float(toco_success) / float(tf_success) * 100.
  tf.logging.info(("Archive %s Considered %d graphs, %d TF evaluated graphs "
                   " and %d TOCO converted graphs (%.1f%%)"), zip_path,
                  total_conversions, tf_success, toco_success, percent)

  if not FLAGS.ignore_toco_errors and toco_errors > 0:
    raise RuntimeError(
        "Found %d errors while generating toco models" % toco_errors)


def make_pool_tests(pool_op_in):
  """Create a test generator for the given pooling operation.

  Args:
    pool_op_in: TensorFlow pooling operation to test  i.e. `tf.nn.avg_pool`.

  Returns:
    A function taking a zip path that generates the examples for
    `pool_op_in` (i.e. the generator with the pooling op curried in).
  """

  def generate_pool_examples(zip_path):
    """Write the zip of pooling test cases.

    Args:
      zip_path: path to write zip to.
    """

    # The same candidate list is used for window sizes and strides.
    window_candidates = [[2, 1, 1, 2], [1, 1, 1, 1], [1, 1, 2, 1],
                         [1, 10, 11, 1]]
    test_parameters = [{
        "ksize": list(window_candidates),
        "strides": list(window_candidates),
        # TODO(aselle): should add in a degenerate shape (e.g. [1, 0, 1, 1]).
        "input_shape": [[], [1, 1, 1, 1], [1, 15, 14, 1], [3, 15, 14, 3]],
        "padding": ["SAME", "VALID"],
        "data_format": ["NHWC"],  # TODO(aselle): NCHW  would be good
    }]

    def build_graph(parameters):
      """Build a graph applying the pooling op to a float placeholder."""
      placeholder = tf.placeholder(
          dtype=tf.float32, name="input", shape=parameters["input_shape"])
      pooled = pool_op_in(
          placeholder,
          ksize=parameters["ksize"],
          strides=parameters["strides"],
          data_format=parameters["data_format"],
          padding=parameters["padding"])
      return [placeholder], [pooled]

    def build_inputs(parameters, sess, inputs, outputs):
      """Create random input data and evaluate the expected outputs."""
      data = create_tensor_data(tf.float32, parameters["input_shape"])
      expected = sess.run(outputs, feed_dict={inputs[0]: data})
      return [data], expected

    make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)

  return generate_pool_examples


def make_relu_tests(zip_path):
  """Generate the zip of test cases exercising `tf.nn.relu`."""

  # Shapes to cover, from a scalar up to a 6-D tensor.
  shapes = [[], [1], [2, 3], [1, 1, 1, 1], [1, 3, 4, 3],
            [3, 15, 14, 3], [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]]
  test_parameters = [{"input_shape": shapes}]

  def build_graph(parameters):
    """Build a graph applying relu to a single float placeholder."""
    placeholder = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    activated = tf.nn.relu(placeholder)
    return [placeholder], [activated]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create random input data and evaluate the expected outputs."""
    data = create_tensor_data(
        np.float32, parameters["input_shape"], min_value=-4, max_value=10)
    expected = sess.run(outputs, feed_dict={inputs[0]: data})
    return [data], expected

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_relu1_tests(zip_path):
  """Generate the zip of test cases exercising relu1 (clamp to [-1, 1])."""

  shapes = [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
            [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]]
  test_parameters = [{"input_shape": shapes}]

  def build_graph(parameters):
    """Build relu1 as a maximum followed by a minimum."""
    placeholder = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    # Note that the following is not supported:
    #   out = tf.maximum(-1.0, tf.minimum(input_tensor, 1.0))
    bounded_below = tf.maximum(placeholder, -1.0)
    clamped = tf.minimum(1.0, bounded_below)
    return [placeholder], [clamped]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create random input data and evaluate the expected outputs."""
    data = create_tensor_data(
        np.float32, parameters["input_shape"], min_value=-3, max_value=10)
    expected = sess.run(outputs, feed_dict={inputs[0]: data})
    return [data], expected

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_relu6_tests(zip_path):
  """Make a set of tests to do relu6.

  Args:
    zip_path: path of the zip file to write the test cases to.
  """

  # Chose a set of parameters
  test_parameters = [{
      "input_shape": [[], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
                      [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]],
  }]

  def build_graph(parameters):
    """Build a graph applying relu6 to a single float placeholder."""
    input_tensor = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    # Use relu6 (not plain relu) so this generator covers the op it is
    # named after.
    out = tf.nn.relu6(input_tensor)
    return [input_tensor], [out]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create random input data and evaluate the expected outputs."""
    # The range [-3, 10) straddles both clipping points of relu6 (0 and 6).
    input_values = create_tensor_data(
        np.float32, parameters["input_shape"], min_value=-3, max_value=10)
    return [input_values], sess.run(
        outputs, feed_dict=dict(zip(inputs, [input_values])))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


# Exercises a TensorFlow function that produces a Const op in the graph
# (`tf.ones` here).
def make_constant_tests(zip_path):
  """Generate the zip of test cases exercising constant ops."""

  test_parameters = [{
      "dtype": [tf.float32, tf.int32],
      "input_shape": [[1], [2], [1, 1, 1, 1], [2, 2, 2, 2]],
  }]

  def build_graph(parameters):
    """Build a graph adding a constant tensor of ones to a placeholder."""
    # Toco & Tflite can't handle a graph consisting of a single constant op,
    # so the constant is added to an input that is fed with zeros.
    dtype = parameters["dtype"]
    shape = parameters["input_shape"]
    placeholder = tf.placeholder(dtype=dtype, name="input1", shape=shape)
    ones = tf.ones(shape, dtype=dtype)
    total = ones + placeholder
    return [placeholder], [total]

  def build_inputs(parameters, sess, inputs, outputs):
    """Feed an all-zero input and evaluate the expected outputs."""
    numpy_dtype = _TF_TYPE_INFO[parameters["dtype"]][0]
    zeros = np.zeros(parameters["input_shape"], dtype=numpy_dtype)
    expected = sess.run(outputs, feed_dict={inputs[0]: zeros})
    return [zeros], expected

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_binary_op_tests(zip_path, binary_operator):
  """Generate the zip of test cases for a two-operand operator.

  Args:
    zip_path: path of the zip file to write the test cases to.
    binary_operator: callable taking two tensors and returning one tensor.
  """

  # Three separate parameter sets: equal 4-D shapes, equal 1-D shapes, and a
  # 4-D operand combined with a 1-D operand.
  same_shape_4d = {
      "dtype": [tf.float32, tf.int32],
      "input_shape_1": [[1, 3, 4, 3]],
      "input_shape_2": [[1, 3, 4, 3]],
      "activation": [True]
  }
  same_shape_1d = {
      "dtype": [tf.float32],
      "input_shape_1": [[5]],
      "input_shape_2": [[5]],
      "activation": [False, True]
  }
  mixed_rank = {
      "dtype": [tf.float32],
      "input_shape_1": [[1, 3, 4, 3]],
      "input_shape_2": [[3]],
      "activation": [True]
  }
  test_parameters = [same_shape_4d, same_shape_1d, mixed_rank]

  def build_graph(parameters):
    """Build the operator graph, optionally followed by a relu."""
    lhs = tf.placeholder(
        dtype=parameters["dtype"],
        name="input1",
        shape=parameters["input_shape_1"])
    rhs = tf.placeholder(
        dtype=parameters["dtype"],
        name="input2",
        shape=parameters["input_shape_2"])
    result = binary_operator(lhs, rhs)
    if parameters["activation"]:
      result = tf.nn.relu(result)
    return [lhs, rhs], [result]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create random data for both operands and evaluate the outputs."""
    # The first operand is generated before the second one.
    lhs_data = create_tensor_data(parameters["dtype"],
                                  parameters["input_shape_1"])
    rhs_data = create_tensor_data(parameters["dtype"],
                                  parameters["input_shape_2"])
    feed = {inputs[0]: lhs_data, inputs[1]: rhs_data}
    expected = sess.run(outputs, feed_dict=feed)
    return [lhs_data, rhs_data], expected

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_mean_tests(zip_path):
  """Make a set of tests to do mean.

  Each graph is a single `tf.reduce_mean` over one placeholder, using the
  "axis" and "keep_dims" values of the current parameter combination.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "input_dtype": [tf.float32, tf.int32, tf.int64],
          "input_shape": [[3, 2, 4]],
          "axis": [
              None, 0, 1, 2,
              [0, 1], [0, 2], [1, 2], [0, 1, 2],
              [1, 0], [2, 0], [2, 1], [2, 1, 0], [2, 0, 1],
              -1, -2, -3,
              [1, -1], [0, -1], [-1, 0], [-1, -2, -3],
              [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
          ],
          "keep_dims": [True, False],
      },
      {
          "input_dtype": [tf.float32, tf.int32, tf.int64],
          "input_shape": [[1, 224, 224, 3]],
          "axis": [
              None, 0, 1, 2, 3,
              [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
              [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0],
              -1, -2, -3, -4,
              [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4],
              [2, 2, 2], [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
          ],
          "keep_dims": [True, False],
      },
  ]

  def build_graph(parameters):
    """Build the mean op testing graph."""
    source = tf.placeholder(
        dtype=parameters["input_dtype"],
        name="input",
        shape=parameters["input_shape"])
    reduced = tf.reduce_mean(
        source, axis=parameters["axis"], keep_dims=parameters["keep_dims"])
    return [source], [reduced]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["input_dtype"],
                           parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_binary_op_tests_func(binary_operator):
  """Return a function that does a test on a binary operator.

  Args:
    binary_operator: Operator that the returned function hands to
      `make_binary_op_tests`.

  Returns:
    A one-argument callable taking `zip_path` and returning whatever
    `make_binary_op_tests(zip_path, binary_operator)` returns.
  """

  def _make_tests(zip_path):
    return make_binary_op_tests(zip_path, binary_operator)

  return _make_tests


def make_gather_tests(zip_path):
  """Make a set of tests to do gather.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # TODO(mgubin): add string tests when they are supported by Toco.
  # TODO(mgubin): add tests for Nd indices when they are supported by
  # TfLite.
  # TODO(mgubin): add tests for axis != 0 when it is supported by TfLite.
  test_parameters = [
      {
          "params_dtype": [tf.float32, tf.int32],
          "params_shape": [[10], [1, 2, 20]],
          "indices_dtype": [tf.int32],
          "indices_shape": [[3], [5]],
          "axis": [0],  # axis!=0 is GatherV2
      },
  ]

  def build_graph(parameters):
    """Build the gather op testing graph."""
    params = tf.placeholder(
        dtype=parameters["params_dtype"],
        name="params",
        shape=parameters["params_shape"])
    indices = tf.placeholder(
        dtype=parameters["indices_dtype"],
        name="indices",
        shape=parameters["indices_shape"])
    gathered = tf.gather(params, indices, axis=parameters["axis"])
    return [params, indices], [gathered]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create params and in-range indices, then evaluate the graph."""
    params_shape = parameters["params_shape"]
    params_value = create_tensor_data(parameters["params_dtype"],
                                      params_shape)
    # The two trailing arguments bound the index values by the size of the
    # first dimension of params.
    last_valid_index = params_shape[0] - 1
    indices_value = create_tensor_data(parameters["indices_dtype"],
                                       parameters["indices_shape"], 0,
                                       last_valid_index)
    values = [params_value, indices_value]
    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_global_batch_norm_tests(zip_path):
  """Make a set of tests to do batch_norm_with_global_normalization.

  The normalization is applied to data obtained from `create_tensor_data`
  at graph-construction time; its result is then added to a placeholder so
  the graph still has one feedable input.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 1, 6, 2], [3, 4, 5, 4]],
          "epsilon": [0.1, 0.0001],
          "scale_after": [True, False],
      },
  ]

  def build_graph(parameters):
    """Build the global batch norm testing graph."""
    dtype = parameters["dtype"]
    full_shape = parameters["input_shape"]
    # scale/offset/mean/variance are sized by the last (index 3) dimension.
    per_channel = full_shape[3]

    scale = create_tensor_data(dtype, per_channel)
    offset = create_tensor_data(dtype, per_channel)
    mean = create_tensor_data(dtype, per_channel)
    variance = create_tensor_data(dtype, per_channel)
    x = create_tensor_data(dtype, full_shape)

    normalized = tf.nn.batch_norm_with_global_normalization(
        x, mean, variance, scale, offset, parameters["epsilon"],
        parameters["scale_after"])

    placeholder = tf.placeholder(dtype=dtype, name="input", shape=full_shape)
    return [placeholder], [tf.add(placeholder, normalized)]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the placeholder value and evaluate the graph on it."""
    value = create_tensor_data(parameters["dtype"],
                               parameters["input_shape"])
    feed = dict(zip(inputs, [value]))
    return [value], sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_fused_batch_norm_tests(zip_path):
  """Make a set of tests to do fused_batch_norm.

  The fused batch norm runs in inference mode (`is_training=False`) on data
  obtained from `create_tensor_data` at graph-construction time; its first
  output is added to a placeholder so the graph has one feedable input.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 1, 6, 2]],
          "epsilon": [0.001, 0.1],
      },
  ]

  def build_graph(parameters):
    """Build the testing graph for fused batch normalization."""
    dtype = parameters["dtype"]
    full_shape = parameters["input_shape"]
    # scale/offset/mean/variance are sized by the last (index 3) dimension.
    per_channel = full_shape[3]

    scale = create_tensor_data(dtype, per_channel)
    offset = create_tensor_data(dtype, per_channel)
    mean = create_tensor_data(dtype, per_channel)
    variance = create_tensor_data(dtype, per_channel)
    x = create_tensor_data(dtype, full_shape)

    # Only the first of the three results is used.
    normalized, _, _ = tf.nn.fused_batch_norm(
        x, scale, offset, mean, variance, parameters["epsilon"],
        data_format="NHWC", is_training=False)

    placeholder = tf.placeholder(dtype=dtype, name="input", shape=full_shape)
    return [placeholder], [tf.add(placeholder, normalized)]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the placeholder value and evaluate the graph on it."""
    value = create_tensor_data(parameters["dtype"],
                               parameters["input_shape"])
    feed = dict(zip(inputs, [value]))
    return [value], sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_conv_tests(zip_path):
  """Make a set of tests to do convolution.

  Depending on "constant_filter", the filter is either data obtained from
  `create_tensor_data` when the graph is built, or a second placeholder.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "input_shape": [[1, 3, 4, 3]],
          "filter_shape": [[1, 1, 3, 2]],
          "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
          "padding": ["SAME", "VALID"],
          "data_format": ["NHWC"],  # TODO(aselle): NCHW  would be good
          "constant_filter": [True, False],
      },
      {
          "input_shape": [[2, 14, 14, 2]],
          "filter_shape": [[6, 6, 2, 2]],
          "strides": [[1, 1, 1, 1], [1, 2, 3, 1]],
          "padding": ["SAME", "VALID"],
          "data_format": ["NHWC"],  # TODO(aselle): NCHW  would be good
          "constant_filter": [True, False],
      },
  ]

  def build_graph(parameters):
    """Build a conv graph given `parameters`."""
    image = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    placeholders = [image]

    # The filter is either a feedable placeholder (and then counts as a graph
    # input) or data created right here.
    if not parameters["constant_filter"]:
      kernel = tf.placeholder(
          dtype=tf.float32, name="filter", shape=parameters["filter_shape"])
      placeholders = [image, kernel]
    else:
      kernel = create_tensor_data(np.float32, parameters["filter_shape"])

    conv = tf.nn.conv2d(
        image,
        kernel,
        strides=parameters["strides"],
        padding=parameters["padding"],
        data_format=parameters["data_format"])
    return placeholders, [conv]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create one value per placeholder and evaluate the graph."""
    # One value (input) when the filter is constant, otherwise two
    # (input, filter).
    feed_values = [create_tensor_data(np.float32, parameters["input_shape"])]
    if not parameters["constant_filter"]:
      feed_values += [
          create_tensor_data(np.float32, parameters["filter_shape"])
      ]
    feed = {tensor: value for tensor, value in zip(inputs, feed_values)}
    return feed_values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_depthwiseconv_tests(zip_path):
  """Make a set of tests to do depthwise convolution.

  Depending on "constant_filter", the filter is either data obtained from
  `create_tensor_data` when the graph is built, or a second placeholder.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # Tensorflow only supports equal strides
  test_parameters = [
      {
          "input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
          "filter_size": [[1, 1], [1, 2], [3, 3]],
          "strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
          "channel_multiplier": [1, 2],
          "rate": [[1, 1]],
          "padding": ["SAME", "VALID"],
          "data_format": ["NHWC"],
          "constant_filter": [True, False],
      },
      {
          "input_shape": [[1, 3, 4, 3]],
          "filter_size": [[1, 1]],
          "strides": [[1, 1, 2, 1]],  # TF needs [1, x, x, 1]
          "channel_multiplier": [2],
          "rate": [[2, 2]],  #  Only [1, 1] is supported
          "padding": ["SAME"],
          "data_format": ["NHWC"],
          "constant_filter": [True, False],
      },
  ]

  def get_tensor_shapes(parameters):
    """Return [input_shape, filter_shape] for the current parameters."""
    image_shape = parameters["input_shape"]
    # Filter shape is the filter size followed by the input's last (index 3)
    # dimension and the channel multiplier.
    kernel_shape = parameters["filter_size"] + [
        image_shape[3], parameters["channel_multiplier"]
    ]
    return [image_shape, kernel_shape]

  def build_graph(parameters):
    """Build a depthwise conv graph given `parameters`."""
    image_shape, kernel_shape = get_tensor_shapes(parameters)
    image = tf.placeholder(dtype=tf.float32, name="input", shape=image_shape)
    placeholders = [image]

    # The filter is either a feedable placeholder (and then counts as a graph
    # input) or data created right here.
    if not parameters["constant_filter"]:
      kernel = tf.placeholder(
          dtype=tf.float32, name="filter", shape=kernel_shape)
      placeholders = [image, kernel]
    else:
      kernel = create_tensor_data(np.float32, kernel_shape)

    conv = tf.nn.depthwise_conv2d(
        image,
        kernel,
        strides=parameters["strides"],
        rate=parameters["rate"],
        padding=parameters["padding"],
        data_format=parameters["data_format"])
    return placeholders, [conv]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create one value per placeholder and evaluate the graph."""
    # One value (input) when the filter is constant, otherwise two
    # (input, filter).
    image_shape, kernel_shape = get_tensor_shapes(parameters)
    feed_values = [create_tensor_data(np.float32, image_shape)]
    if not parameters["constant_filter"]:
      feed_values += [create_tensor_data(np.float32, kernel_shape)]
    feed = {tensor: value for tensor, value in zip(inputs, feed_values)}
    return feed_values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_concatenation_tests(zip_path):
  """Make a set of tests to do concatenation.

  The n-th input tensor has `base_shape` with n added to the concatenation
  axis, so the inputs differ in size along that axis only.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "base_shape": [[1, 3, 4, 3], [3, 4]],
          "num_tensors": [1, 2, 3, 4, 5, 6],
          "axis": [0, 1, 2, 3],
      },
  ]

  def get_shape(parameters, delta):
    """Return a copy of 'base_shape' grown by `delta` along the axis."""
    concat_axis = parameters["axis"]
    tweaked = list(parameters["base_shape"])
    if concat_axis >= len(tweaked):
      # Axis is beyond this shape's rank: return the copy untouched.
      return tweaked
    tweaked[concat_axis] += delta
    return tweaked

  def build_graph(parameters):
    """Build one placeholder per tensor and concatenate them."""
    placeholders = [
        tf.placeholder(dtype=tf.float32, name=("input%d" % index),
                       shape=get_shape(parameters, index))
        for index in range(parameters["num_tensors"])
    ]
    joined = tf.concat(placeholders, parameters["axis"])
    return placeholders, [joined]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create one value per placeholder and evaluate the graph."""
    feed_values = [
        create_tensor_data(np.float32, get_shape(parameters, index))
        for index in range(parameters["num_tensors"])
    ]
    feed = dict(zip(inputs, feed_values))
    return feed_values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_fully_connected_tests(zip_path):
  """Make a set of tests to do fully_connected.

  Each graph is a single `tf.matmul`. Depending on "constant_filter", the
  second operand is either data obtained from `create_tensor_data` when the
  graph is built, or a second placeholder.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "shape1": [[3, 3]],
          "shape2": [[3, 3]],
          "transpose_a": [True, False],
          "transpose_b": [True, False],
          "constant_filter": [True, False],
      },
      {
          "shape1": [[4, 4], [1, 4], [4]],
          "shape2": [[4, 4], [4, 1], [4]],
          "transpose_a": [False],
          "transpose_b": [False],
          "constant_filter": [True, False],
      },
      {
          "shape1": [[40, 37]],
          "shape2": [[37, 40]],
          "transpose_a": [False],
          "transpose_b": [False],
          "constant_filter": [True, False],
      },
  ]

  def build_graph(parameters):
    """Build a matmul graph given `parameters`."""
    lhs = tf.placeholder(
        dtype=tf.float32, name="input1", shape=parameters["shape1"])
    placeholders = [lhs]

    # The second operand is either a feedable placeholder (and then counts as
    # a graph input) or data created right here.
    if not parameters["constant_filter"]:
      rhs = tf.placeholder(
          dtype=tf.float32, name="input2", shape=parameters["shape2"])
      placeholders = [lhs, rhs]
    else:
      rhs = create_tensor_data(np.float32, parameters["shape2"])

    product = tf.matmul(
        lhs,
        rhs,
        transpose_a=parameters["transpose_a"],
        transpose_b=parameters["transpose_b"])
    return placeholders, [product]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create one value per placeholder and evaluate the graph."""
    # One value (first operand) when the second operand is constant,
    # otherwise two (first operand, second operand).
    feed_values = [
        create_tensor_data(np.float32, shape=parameters["shape1"])
    ]
    if not parameters["constant_filter"]:
      feed_values += [create_tensor_data(np.float32, parameters["shape2"])]
    feed = {tensor: value for tensor, value in zip(inputs, feed_values)}
    return feed_values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_l2norm_tests(zip_path):
  """Make a set of tests to do l2norm.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # Chose a set of parameters
  test_parameters = [
      {
          "input_shape": [
              [5, 7], [1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3],
              [3, 1, 2, 4, 6], [2, 2, 3, 4, 5, 6]
          ],
          "dim": [0, 1, 2, 3, [2, 3], -2],
          "epsilon": [None, 1e-12, 1e-3],
      },
  ]

  def build_graph(parameters):
    """Build an l2_normalize graph, passing epsilon only when it is set."""
    source = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    epsilon = parameters["epsilon"]
    if not epsilon:
      # Falsy epsilon: omit the argument so TensorFlow's own default applies.
      return [source], [tf.nn.l2_normalize(source, parameters["dim"])]
    normalized = tf.nn.l2_normalize(
        source, parameters["dim"], epsilon=epsilon)
    return [source], [normalized]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(
            np.float32, parameters["input_shape"], min_value=-4,
            max_value=10)
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_local_response_norm_tests(zip_path):
  """Make a set of tests to do local_response_norm.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # Chose a set of parameters
  test_parameters = [
      {
          "input_shape": [[1, 1, 1, 1], [1, 3, 4, 3], [3, 15, 14, 3]],
          "depth_radius": [None, 0, 1, 3, 4, 5],
          "bias": [None, 0.1, 0.3, -0.1],
          "alpha": [None, 1, 2, -3],
          "beta": [None, 0.5, 0.25, 2],
      },
  ]

  def build_graph(parameters):
    """Build a local_response_normalization graph for `parameters`."""
    source = tf.placeholder(
        dtype=tf.float32, name="input", shape=parameters["input_shape"])
    normalized = tf.nn.local_response_normalization(
        source,
        depth_radius=parameters["depth_radius"],
        bias=parameters["bias"],
        alpha=parameters["alpha"],
        beta=parameters["beta"])
    return [source], [normalized]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(
            np.float32, parameters["input_shape"], min_value=-4,
            max_value=10)
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_pad_tests(zip_path):
  """Make a set of tests to do pad.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # TODO(nupurgarg): Add test for tf.uint8.
  test_parameters = [
      {
          "dtype": [tf.int32, tf.int64, tf.float32],
          "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
          "paddings": [
              [[0, 0], [0, 1], [2, 3], [0, 0]],
              [[0, 1], [0, 0], [0, 0], [2, 3]]
          ],
      },
      # Non-4D use case.
      {
          "dtype": [tf.int32, tf.int64, tf.float32],
          "input_shape": [[1, 2], [0, 1, 2]],
          "paddings": [[[0, 1], [2, 3]]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.pad graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    padded = tf.pad(source, paddings=parameters["paddings"])
    return [source], [padded]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_reshape_tests(zip_path):
  """Make a set of tests to do reshape.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # All shapes below are suitable for tensors with 420 elements.
  test_parameters = [
      {
          "dtype": [tf.float32, tf.int32],
          "input_shape": [[3, 4, 5, 7], [4, 105], [21, 5, 2, 2], [420]],
          "output_shape": [[15, 28], [420], [1, -1, 5, 7], [-1]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.reshape graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    reshaped = tf.reshape(source, shape=parameters["output_shape"])
    return [source], [reshaped]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_resize_bilinear_tests(zip_path):
  """Make a set of tests to do resize_bilinear.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32, tf.int32],
          "input_shape": [[1, 3, 4, 3], [1, 10, 2, 1]],
          "size": [[1, 1], [4, 3], [2, 2], [5, 6]],
          "align_corners": [None, True, False],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.image.resize_bilinear graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    resized = tf.image.resize_bilinear(
        source,
        size=parameters["size"],
        align_corners=parameters["align_corners"])
    return [source], [resized]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_sigmoid_tests(zip_path):
  """Make a set of tests to do sigmoid.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 3, 4, 3], [4], [], [1, 2, 3, 4, 5, 6]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.sigmoid graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    return [source], [tf.sigmoid(source)]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_softmax_tests(zip_path):
  """Make a set of tests to do softmax.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 3, 4, 3], [2, 3]],
          "dim": [-1, 0],
      },
      {
          "dtype": [tf.float32],
          "input_shape": [[4, 7]],
          "dim": [-1, 1],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.nn.softmax graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    probabilities = tf.nn.softmax(source, dim=parameters["dim"])
    return [source], [probabilities]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_space_to_depth_tests(zip_path):
  """Make a set of tests to do space_to_depth.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
          "input_shape": [[2, 12, 24, 1]],
          "block_size": [2, 3, 4],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.space_to_depth graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    rearranged = tf.space_to_depth(
        source, block_size=parameters["block_size"])
    return [source], [rearranged]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_space_to_batch_nd_tests(zip_path):
  """Make a set of tests to do space_to_batch_nd.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # TODO(nupurgarg): Add test for uint8.
  test_parameters = [
      {
          "dtype": [tf.int32, tf.int64, tf.float32],
          "input_shape": [[1, 2, 2, 3], [2, 2, 4, 1]],
          "block_shape": [[1, 3], [2, 2]],
          "paddings": [
              [[0, 0], [0, 0]],
              [[0, 0], [2, 0]],
              [[1, 1], [1, 1]]
          ],
      },
      {
          "dtype": [tf.float32],
          "input_shape": [[2, 3, 7, 3]],
          "block_shape": [[1, 3], [2, 2]],
          "paddings": [
              [[0, 0], [2, 0]],
              [[1, 0], [1, 0]]
          ],
      },
      # Non-4D use case: 1 batch dimension, 3 spatial dimensions, 2 others.
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 4, 4, 4, 1, 1]],
          "block_shape": [[2, 2, 2]],
          "paddings": [[[0, 0], [0, 0], [0, 0]]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.space_to_batch_nd graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    rearranged = tf.space_to_batch_nd(
        source, parameters["block_shape"], parameters["paddings"])
    return [source], [rearranged]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_batch_to_space_nd_tests(zip_path):
  """Make a set of tests to do batch_to_space_nd.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.float32, tf.int64, tf.int32],
          "input_shape": [[12, 2, 2, 1]],
          "block_shape": [[1, 4], [2, 2], [3, 4]],
          "crops": [
              [[0, 0], [0, 0]],
              [[1, 1], [1, 1]]
          ],
      },
      # Non-4D use case: 1 batch dimension, 3 spatial dimensions, 2 others.
      {
          "dtype": [tf.float32],
          "input_shape": [[8, 2, 2, 2, 1, 1]],
          "block_shape": [[2, 2, 2]],
          "crops": [[[0, 0], [0, 0], [0, 0]]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.batch_to_space_nd graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    rearranged = tf.batch_to_space_nd(
        source, parameters["block_shape"], parameters["crops"])
    return [source], [rearranged]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_transpose_tests(zip_path):
  """Make a set of tests to do transpose.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # TODO(nupurgarg): Add test for uint8.
  test_parameters = [
      {
          "dtype": [tf.int32, tf.int64, tf.float32],
          "input_shape": [[2, 2, 3]],
          "perm": [[0, 1, 2], [0, 2, 1]],
      },
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 2, 3, 4]],
          "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
      },
      {
          "dtype": [tf.float32],
          "input_shape": [[1, 2, 3, 4, 5]],
          "perm": [[0, 1, 2, 3, 4]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.transpose graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    transposed = tf.transpose(source, perm=parameters["perm"])
    return [source], [transposed]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_squeeze_tests(zip_path):
  """Make a set of tests to do squeeze.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  test_parameters = [
      {
          "dtype": [tf.int32, tf.float32, tf.int64],
          "input_shape": [[1, 2, 1, 3, 1, 4, 1, 1]],
          "axis": [
              None,
              [],
              [0, 2],
              [4, 7],
              [-1, 0, 2, 0, 7, -6],
              [1],
              [2, 3, 2],
              [-1, -2, -4, -6, -8],
              [0, 2, 4, 6, 7],
              [7, 6, 4, 2, 0],
              [6, 6],
              [0, 1, 2, 3, 4, 5, 6, 7],
              [-2, -3, 1, 0, 7, -5]
          ],
      },
      {
          "dtype": [tf.int32, tf.float32, tf.int64],
          "input_shape": [[1]],
          "axis": [None, [], [0], [-1]],
      },
      {
          "dtype": [tf.int32, tf.float32, tf.int64],
          "input_shape": [[1, 1, 1, 1, 1]],
          "axis": [None, [], [0], [3, 0], [-2, 0, 3, 2]],
      },
  ]

  def build_graph(parameters):
    """Build a single tf.squeeze graph for `parameters`."""
    source = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    squeezed = tf.squeeze(source, axis=parameters["axis"])
    return [source], [squeezed]

  def build_inputs(parameters, sess, inputs, outputs):
    """Create the input value and evaluate the graph on it."""
    values = [
        create_tensor_data(parameters["dtype"], parameters["input_shape"])
    ]
    feed = dict(zip(inputs, values))
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_strided_slice_tests(zip_path):
  """Make a set of tests to do strided_slice.

  begin and end are always fed through placeholders. strides is fed through
  a placeholder only when the "strides" parameter is not None; otherwise
  None is passed to `tf.strided_slice` in its place.

  Args:
    zip_path: Forwarded unchanged to `make_zip_of_tests`.
  """

  # TODO(soroosh): add test/support for uint8.
  test_parameters = [
      # 4-D
      {
          "dtype": [tf.float32, tf.int32, tf.int64],
          "index_type": [tf.int32],
          "input_shape": [[12, 2, 2, 5]],
          "begin": [[0, 0, 0, 0], [1, 0, 1, 0]],
          "end": [[8, 2, 2, 3], [12, 2, 2, 5]],
          "strides": [None, [1, 1, 1, 1], [2, 1, 3, 1]],
          "begin_mask": [None, 0, 1, 2, 8],
          "end_mask": [None, 0, 1, 2, 8],
      },
      # 2-D
      {
          "dtype": [tf.float32, tf.int32, tf.int64],
          "index_type": [tf.int32],
          "input_shape": [[2, 3]],
          "begin": [[0, 0], [1, 0]],
          "end": [[2, 3], [2, 2]],
          "strides": [None, [1, 1], [2, 2]],
          "begin_mask": [None, 0, 1, 2],
          "end_mask": [None, 0, 1, 2],
      },
      # Negative strides
      {
          "dtype": [tf.float32, tf.int32, tf.int64],
          "index_type": [tf.int32],
          "input_shape": [[2, 3]],
          "begin": [[0, -1]],
          "end": [[2, -3]],
          "strides": [[1, -1]],
          "begin_mask": [None, 0, 1, 2],
          "end_mask": [None, 0, 1, 2],
      },
  ]

  def build_graph(parameters):
    """Build graph for stride_slice test."""
    index_dtype = parameters["index_type"]
    # begin/end/strides each hold one entry per input dimension.
    index_shape = [len(parameters["input_shape"])]

    input_tensor = tf.placeholder(
        dtype=parameters["dtype"],
        name="input",
        shape=parameters["input_shape"])
    begin = tf.placeholder(dtype=index_dtype, name="begin", shape=index_shape)
    end = tf.placeholder(dtype=index_dtype, name="end", shape=index_shape)
    tensors = [input_tensor, begin, end]

    # The strides placeholder exists (and is a graph input) only when the
    # parameter set supplies strides.
    strides = None
    if parameters["strides"] is not None:
      strides = tf.placeholder(
          dtype=index_dtype, name="strides", shape=index_shape)
      tensors.append(strides)

    sliced = tf.strided_slice(
        input_tensor,
        begin,
        end,
        strides,
        begin_mask=parameters["begin_mask"],
        end_mask=parameters["end_mask"])
    return tensors, [sliced]

  def build_inputs(parameters, sess, inputs, outputs):
    """Build inputs for stride_slice test."""
    input_values = create_tensor_data(parameters["dtype"],
                                      parameters["input_shape"])
    # First element of the _TF_TYPE_INFO entry is used as the dtype for
    # `astype` below.
    index_type = _TF_TYPE_INFO[parameters["index_type"]][0]

    def as_index_array(raw):
      return np.array(raw).astype(index_type)

    values = [
        input_values,
        as_index_array(parameters["begin"]),
        as_index_array(parameters["end"]),
    ]
    if parameters["strides"] is not None:
      values.append(as_index_array(parameters["strides"]))

    feed = {tensor: value for tensor, value in zip(inputs, values)}
    return values, sess.run(outputs, feed_dict=feed)

  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)


def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
  """Given an input perform a sequence of TensorFlow ops to produce l2pool.

  Computes sqrt(avg_pool(square(input_tensor))).

  Args:
    input_tensor: Tensor to pool.
    ksize: Passed to `tf.nn.avg_pool` as `ksize`.
    strides: Passed to `tf.nn.avg_pool` as `strides`.
    padding: Passed to `tf.nn.avg_pool` as `padding`.
    data_format: Passed to `tf.nn.avg_pool` as `data_format`.

  Returns:
    The tensor produced by `tf.sqrt` on the pooled squares.
  """
  squared = tf.square(input_tensor)
  mean_of_squares = tf.nn.avg_pool(
      squared,
      ksize=ksize,
      strides=strides,
      padding=padding,
      data_format=data_format)
  return tf.sqrt(mean_of_squares)


# Toco binary path provided by the generate rule.
# Starts out as None; main() declares this name `global`, presumably so it
# can assign the real path at startup -- TODO confirm against the rest of
# main().
bin_path = None


def main(unused_args):
  """Generates the zip of test examples selected by the command-line flags.

  Reads FLAGS.type, FLAGS.output_path, FLAGS.zip_to_output and FLAGS.toco.
  Stores FLAGS.toco in the module-level `bin_path`, makes sure the output
  directory exists, and then calls the generator registered for
  FLAGS.zip_to_output with the full path of the zip file to write.

  Args:
    unused_args: Ignored.

  Raises:
    RuntimeError: If FLAGS.type is not "zipped", if FLAGS.zip_to_output is not
      one of the known zip names, or if the output directory is still missing
      after attempting to create it.
    OSError: If the output directory cannot be created.
  """
  global bin_path

  def mkdir_if_not_exist(x):
    """Creates directory `x` (and missing parents) unless it already exists."""
    if os.path.isdir(x):
      return
    try:
      os.makedirs(x)
    except OSError:
      # Another process may have created the directory between the isdir()
      # check and makedirs(); only propagate genuine failures.
      if not os.path.isdir(x):
        raise
    if not os.path.isdir(x):
      raise RuntimeError("Failed to create dir %r" % x)

  if FLAGS.type != "zipped":
    raise RuntimeError("Invalid argument for type of generation.")

  opstest_path = FLAGS.output_path
  mkdir_if_not_exist(opstest_path)

  def _path(filename):
    """Returns `filename` placed inside the output directory."""
    return os.path.join(opstest_path, filename)

  # Maps each output zip name to the function that generates it. Every
  # generator is called with the destination zip path as its only argument.
  dispatch = {
      "control_dep.zip": make_control_dep_tests,
      "add.zip": make_binary_op_tests_func(tf.add),
      "space_to_batch_nd.zip": make_space_to_batch_nd_tests,
      "div.zip": make_binary_op_tests_func(tf.div),
      "sub.zip": make_binary_op_tests_func(tf.subtract),
      "batch_to_space_nd.zip": make_batch_to_space_nd_tests,
      "conv.zip": make_conv_tests,
      "constant.zip": make_constant_tests,
      "depthwiseconv.zip": make_depthwiseconv_tests,
      "concat.zip": make_concatenation_tests,
      "fully_connected.zip": make_fully_connected_tests,
      "global_batch_norm.zip": make_global_batch_norm_tests,
      "gather.zip": make_gather_tests,
      "fused_batch_norm.zip": make_fused_batch_norm_tests,
      "l2norm.zip": make_l2norm_tests,
      "local_response_norm.zip": make_local_response_norm_tests,
      "mul.zip": make_binary_op_tests_func(tf.multiply),
      "relu.zip": make_relu_tests,
      "relu1.zip": make_relu1_tests,
      "relu6.zip": make_relu6_tests,
      "l2_pool.zip": make_pool_tests(make_l2_pool),
      "avg_pool.zip": make_pool_tests(tf.nn.avg_pool),
      "max_pool.zip": make_pool_tests(tf.nn.max_pool),
      "pad.zip": make_pad_tests,
      "reshape.zip": make_reshape_tests,
      "resize_bilinear.zip": make_resize_bilinear_tests,
      "sigmoid.zip": make_sigmoid_tests,
      "softmax.zip": make_softmax_tests,
      "space_to_depth.zip": make_space_to_depth_tests,
      "transpose.zip": make_transpose_tests,
      "mean.zip": make_mean_tests,
      "squeeze.zip": make_squeeze_tests,
      "strided_slice.zip": make_strided_slice_tests,
  }
  out = FLAGS.zip_to_output
  # Set before dispatching so the generator can read the module-level value.
  bin_path = FLAGS.toco
  if out not in dispatch:
    raise RuntimeError("Invalid zip to output %r" % out)
  dispatch[out](_path(out))


if __name__ == "__main__":
  # Arguments the parser does not recognize are collected in `unparsed`.
  FLAGS, unparsed = parser.parse_known_args()

  if unparsed:
    # Unrecognized arguments: show the usage line (with the program name
    # substituted for %s) and do not run the generator.
    print("Usage: %s <path out> zipped <zip file to generate>" % sys.argv[0])
  else:
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)