diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-05-30 17:54:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-30 17:56:47 -0700 |
commit | 316549d36f6ab3d250ce9e33b768bbfb1a4d7362 (patch) | |
tree | cef32a4c8ace3dedac532c14fd39944d5bc4ed2b /tensorflow/contrib/lite/python/lite_test.py | |
parent | 2a484497062677f5cf0205ee3b9c28a64f03fe04 (diff) |
Enable TOCO pip command line binding.
PiperOrigin-RevId: 198649827
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 180 |
1 files changed, 171 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 2f3105f3e6..28386ecb1a 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -29,8 +29,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph class FromSessionTest(test_util.TensorFlowTestCase): @@ -65,16 +67,22 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertEqual((0., 0.), output_details[0]['quantization']) def testQuantization(self): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') out_tensor = array_ops.fake_quant_with_min_max_args( - in_tensor + in_tensor, min=0., max=1., name='output') + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) converter.inference_type = lite_constants.QUANTIZED_UINT8 - converter.quantized_input_stats = [(0., 1.)] # mean, std_dev + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -83,13 +91,19 @@ class FromSessionTest(test_util.TensorFlowTestCase): interpreter.allocate_tensors() input_details = interpreter.get_input_details() - self.assertEqual(1, len(input_details)) - self.assertEqual('input', input_details[0]['name']) + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) self.assertEqual(np.uint8, input_details[0]['dtype']) self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) self.assertEqual((1., 0.), input_details[0]['quantization']) # scale, zero_point + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + output_details = interpreter.get_output_details() self.assertEqual(1, len(output_details)) self.assertEqual('output', output_details[0]['name']) @@ -97,6 +111,26 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + def testBatchSizeInvalid(self): in_tensor = array_ops.placeholder( shape=[None, 16, 16, 3], dtype=dtypes.float32) @@ -152,8 +186,7 @@ class FromSessionTest(test_util.TensorFlowTestCase): sess = session.Session() # Convert model and ensure model is not None. - converter = lite.TocoConverter.from_session( - sess, [in_tensor], [out_tensor], freeze_variables=True) + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) tflite_model = converter.convert() self.assertTrue(tflite_model) @@ -188,6 +221,135 @@ class FromSessionTest(test_util.TensorFlowTestCase): self.assertTrue(graphviz_output) +class FromFlatbufferFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_flatbuffer_file( + graph_def_file, ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_flatbuffer_file(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + + class FromSavedModelTest(test_util.TensorFlowTestCase): def _createSavedModel(self, shape): |