aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/convert_test.py
blob: 40a8b5fafb2dbf3b30dfae4ad307737b18782480 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite Python Interface: Sanity check."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python import op_hint
from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ConvertTest(test_util.TensorFlowTestCase):
  """Sanity checks for the `convert.toco_convert*` entry points.

  Each test builds a small graph, converts it, and checks that a non-empty
  model is returned. The GraphDef-based tests additionally load the model into
  an `Interpreter` and verify the input/output tensor metadata (name, dtype,
  shape and quantization parameters).
  """

  def testBasic(self):
    """Converts a float `x + x` graph given as input/output tensors."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor

    # Try running on valid graph. The session is only needed for its
    # GraphDef, so it is closed as soon as the conversion is done.
    with session.Session() as sess:
      tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
                                          [out_tensor])
    self.assertTrue(tflite_model)

    # TODO(aselle): remove tests that fail (we must get TOCO to not fatal
    # all the time).
    # Try running on identity graph (known fail)
    # with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
    #   result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])

  def testQuantization(self):
    """Converts a fake-quantized graph with uint8 inference."""
    in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
                                      dtype=dtypes.float32)
    out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
                                                        min=0., max=1.)

    with session.Session() as sess:
      tflite_model = convert.toco_convert(
          sess.graph_def, [in_tensor], [out_tensor],
          inference_type=lite_constants.QUANTIZED_UINT8,
          # One stats tuple per input array.
          quantized_input_stats=[(0., 1.)])
    self.assertTrue(tflite_model)

  def testGraphDefBasic(self):
    """Converts a float graph by node names and checks tensor metadata."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
    # The resulting op is named "add" by default; it is referenced by name
    # as the output array below.
    _ = in_tensor + in_tensor

    with session.Session() as sess:
      tflite_model = convert.toco_convert_graph_def(
          sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
          inference_type=lite_constants.FLOAT)
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual("input", input_details[0]["name"])
    self.assertEqual(np.float32, input_details[0]["dtype"])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]["shape"])
    self.assertEqual((0., 0.), input_details[0]["quantization"])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("add", output_details[0]["name"])
    self.assertEqual(np.float32, output_details[0]["dtype"])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]["shape"])
    self.assertEqual((0., 0.), output_details[0]["quantization"])

  def testGraphDefQuantization(self):
    """Converts a two-input fake-quantized graph with uint8 inference."""
    in_tensor_1 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
    in_tensor_2 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
    _ = array_ops.fake_quant_with_min_max_args(
        in_tensor_1 + in_tensor_2, min=0., max=1., name="output")

    input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
    output_arrays = ["output"]
    with session.Session() as sess:
      tflite_model = convert.toco_convert_graph_def(
          sess.graph_def,
          input_arrays_map,
          output_arrays,
          inference_type=lite_constants.QUANTIZED_UINT8,
          # One stats tuple per entry in `input_arrays_map`.
          quantized_input_stats=[(0., 1.), (0., 1.)])
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual("inputA", input_details[0]["name"])
    self.assertEqual(np.uint8, input_details[0]["dtype"])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]["shape"])
    self.assertEqual((1., 0.),
                     input_details[0]["quantization"])  # scale, zero_point

    self.assertEqual("inputB", input_details[1]["name"])
    self.assertEqual(np.uint8, input_details[1]["dtype"])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]["shape"])
    self.assertEqual((1., 0.),
                     input_details[1]["quantization"])  # scale, zero_point

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual("output", output_details[0]["name"])
    self.assertEqual(np.uint8, output_details[0]["dtype"])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]["shape"])
    # Only the sign of the output scale is checked, not its exact value.
    self.assertGreater(output_details[0]["quantization"][0], 0)  # scale


class ConvertTestOpHint(test_util.TensorFlowTestCase):
  """Test the hint to stub functionality."""

  def _getGraphOpTypes(self, graphdef, output_nodes):
    """Returns used op types in `graphdef` reachable from `output_nodes`.

    This is used to check that after the stub transformation the expected
    nodes are there. Typically use this with self.assertCountEqual(...).

    NOTE: this is not an exact test that the graph is the correct output, but
      it balances compact expressibility of test with sanity checking.

    Args:
      graphdef: TensorFlow proto graphdef.
      output_nodes: A list of output node names that we need to reach.

    Returns:
      A set of node types reachable from `output_nodes`.
    """
    name_to_input_name, name_to_node, _ = _extract_graph_summary(graphdef)
    # Find all nodes that are needed by the outputs.
    used_node_names = _bfs_for_reachable_nodes(output_nodes, name_to_input_name)
    return {name_to_node[node_name].op for node_name in used_node_names}

  def _countIdentities(self, nodes):
    """Count the number of "Identity" op types in the list of proto nodes.

    Args:
      nodes: NodeDefs of the graph.

    Returns:
      The number of nodes with op type "Identity" found.
    """
    return sum(1 for node in nodes if node.op == "Identity")

  def testSwishLiteHint(self):
    """Makes a custom op swish and makes sure it gets converted as a unit."""
    image = array_ops.constant([1., 2., 3., 4.])
    swish_scale = array_ops.constant(1.0)

    def _swish(input_tensor, scale):
      custom = op_hint.OpHint("cool_activation")
      input_tensor, scale = custom.add_inputs(input_tensor, scale)
      output = math_ops.sigmoid(input_tensor) * input_tensor * scale
      output, = custom.add_outputs(output)
      return output
    output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")

    with self.cached_session() as sess:
      # check if identities have been put into the graph (2 input, 1 output,
      # and 1 final output).
      self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
          graph_def=sess.graph_def)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef,
              output_nodes=[op_hint._tensor_name_base(output.name)]),
          ["cool_activation", "Const", "Identity"])

  def testScaleAndBiasAndIdentity(self):
    """This tests a scaled add which has 3 inputs and 2 outputs."""
    a = array_ops.constant(1.)
    x = array_ops.constant([2., 3.])
    b = array_ops.constant([4., 5.])

    def _scaled_and_bias_and_identity(a, x, b):
      custom = op_hint.OpHint("scale_and_bias_and_identity")
      a, x, b = custom.add_inputs(a, x, b)
      return custom.add_outputs(a * x + b, x)
    # The two hinted outputs reach `identity` as a sequence; presumably that
    # auto-packing is where the expected "Pack" op below comes from.
    output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
                                name="ModelOutput")

    with self.cached_session() as sess:
      # make sure one identity for each input (3) and output (2) => 3 + 2 = 5
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
          graph_def=sess.graph_def)

      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef,
              output_nodes=[op_hint._tensor_name_base(output.name)]),
          ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])

  def testTwoFunctions(self):
    """Tests if two functions are converted correctly."""
    a = array_ops.constant([1.])
    b = array_ops.constant([1.])
    def _double_values(x):
      custom = op_hint.OpHint("add_test")
      x, = custom.add_inputs(x)
      output = math_ops.multiply(x, x)
      output, = custom.add_outputs(output)
      return output
    output = array_ops.identity(
        math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")

    with self.cached_session() as sess:
      # make sure one identity for each input (2) and output (2) => 2 + 2
      # +1 for the final output
      self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
          graph_def=sess.graph_def)
      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef,
              output_nodes=[op_hint._tensor_name_base(output.name)]),
          ["add_test", "Const", "Identity", "Add"])

  def _get_input_index(self, x):
    """Returns the FUNCTION_INPUT_INDEX_ATTR int attribute of `x`'s op."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i

  def _get_output_index(self, x):
    """Returns the FUNCTION_OUTPUT_INDEX_ATTR int attribute of `x`'s op."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i

  def _get_sort_index(self, x):
    """Returns the FUNCTION_SORT_INDEX_ATTR int attribute of `x`'s op."""
    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i

  def testTags(self):
    """Test if multiple args with the same tag are grouped."""
    a = array_ops.constant([1.])
    b = array_ops.constant([2.])
    c = array_ops.constant([3.])
    d = array_ops.constant([4.])
    custom = op_hint.OpHint("test_tag")
    a = custom.add_input(a, tag="mytag",
                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
    b, = custom.add_inputs(b)
    c = custom.add_input(c, tag="mytag",
                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
    # `d` is registered under its own tag but not inspected below.
    d = custom.add_input(d, tag="mytag2",
                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
    res = math_ops.add(math_ops.multiply(a, b), math_ops.multiply(c, b))
    # Pass the tensor itself (as every other test does), not a one-element
    # list, so it is registered as a single output.
    custom.add_outputs(res)
    with self.cached_session():
      # Both "mytag" inputs share input index 0 and are ordered by sort index.
      self.assertEqual(self._get_input_index(a), 0)
      self.assertEqual(self._get_sort_index(a), 0)
      self.assertEqual(self._get_input_index(b), 1)
      self.assertEqual(self._get_input_index(c), 0)
      self.assertEqual(self._get_sort_index(c), 1)

  def testOverrideIndex(self):
    """Tests that `index_override` pins an input's index among auto ones."""
    a = array_ops.constant([1.])
    b = array_ops.constant([2.])
    c = array_ops.constant([3.])
    custom = op_hint.OpHint("test_override")
    b = custom.add_input(b)  # should auto assign 0
    a = custom.add_input(a, index_override=1)
    c = custom.add_input(c)  # should auto assign 2
    with self.cached_session():
      self.assertEqual(self._get_input_index(a), 1)
      self.assertEqual(self._get_input_index(b), 0)
      self.assertEqual(self._get_input_index(c), 2)

  def testAggregate(self):
    """Tests stub conversion of inputs/outputs tagged with AGGREGATE_STACK.

    After conversion only the op types listed below should be reachable from
    the final output.
    """
    a = array_ops.constant([3., 4.])
    b = array_ops.constant([5., 6.])
    hint = op_hint.OpHint("agg")
    a0, a1 = array_ops.unstack(a)
    b0, b1 = array_ops.unstack(b)

    a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
    b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
    a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
    b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)

    c0 = math_ops.add(a0, b0, name="addleft")
    c1 = math_ops.add(a1, b1, name="addright")
    c0 = hint.add_output(
        c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
    c1 = hint.add_output(
        c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)

    curr = array_ops.stack([c0, c1])
    output = array_ops.identity(curr, name="FINAL_OUTPUT")
    with self.cached_session() as sess:
      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
          graph_def=sess.graph_def)
      self.assertCountEqual(
          self._getGraphOpTypes(
              stubbed_graphdef,
              output_nodes=[op_hint._tensor_name_base(output.name)]),
          ["agg", "Const", "Identity"])


if __name__ == "__main__":
  # Run all test cases in this module via tensorflow.python.platform.test.
  test.main()