diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-08-28 18:16:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 18:23:02 -0700 |
commit | 2e7352e57c541908cd700bb0fe53a04b456392c9 (patch) | |
tree | 2064e341d1a7f154d0cb1d42910359c7dd3e5a02 /tensorflow/contrib/lite/python/lite_test.py | |
parent | c4099e6ee8ba3846f2b7e70445806bc3055c5624 (diff) |
Add more model support to TocoConverter.
PiperOrigin-RevId: 210643904
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 2f13684228..8c9cfa943f 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer from tensorflow.python.platform import gfile +from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model from tensorflow.python.training.training_util import write_graph +class FromConstructor(test_util.TensorFlowTestCase): + + # Tests invalid constructors using a dummy value for the GraphDef. + def testInvalidConstructor(self): + message = ('If input_tensors and output_tensors are None, both ' + 'input_arrays_with_shape and output_arrays must be defined.') + + # `output_arrays` is not defined. + with self.assertRaises(ValueError) as error: + lite.TocoConverter( + None, None, [], input_arrays_with_shape=[('input', [3, 9])]) + self.assertEqual(message, str(error.exception)) + + # `input_arrays_with_shape` is not defined. + with self.assertRaises(ValueError) as error: + lite.TocoConverter(None, [], None, output_arrays=['output']) + self.assertEqual(message, str(error.exception)) + + # Tests valid constructors using a dummy value for the GraphDef. + def testValidConstructor(self): + converter = lite.TocoConverter( + None, + None, + None, + input_arrays_with_shape=[('input', [3, 9])], + output_arrays=['output']) + self.assertFalse(converter._has_valid_tensors()) + self.assertEqual(converter.get_input_arrays(), ['input']) + + with self.assertRaises(ValueError) as error: + converter._set_batch_size(1) + self.assertEqual( + 'The batch size cannot be set for this model. Please use ' + 'input_shapes parameter.', str(error.exception)) + + converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor']) + self.assertTrue(converter._has_valid_tensors()) + + class FromSessionTest(test_util.TensorFlowTestCase): def testFloat(self): @@ -490,6 +530,79 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): 'Unable to parse input file \'{}\'.'.format(graph_def_file), str(error.exception)) + # TODO(nupurgarg): Test model loading in open source. + def _initObjectDetectionArgs(self): + # Initializes the arguments required for the object detection model. + self._graph_def_file = resource_loader.get_path_to_datafile( + 'testdata/tflite_graph.pbtxt') + self._input_arrays = ['normalized_input_image_tensor'] + self._output_arrays = [ + 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', + 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3' + ] + self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]} + + def testTFLiteGraphDef(self): + # Tests the object detection model that cannot be loaded in TensorFlow. + self._initObjectDetectionArgs() + + converter = lite.TocoConverter.from_frozen_graph( + self._graph_def_file, self._input_arrays, self._output_arrays, + self._input_shapes) + converter.allow_custom_ops = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('normalized_input_image_tensor', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(4, len(output_details)) + self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('TFLite_Detection_PostProcess:1', + output_details[1]['name']) + self.assertTrue(([1, 10] == output_details[1]['shape']).all()) + self.assertEqual('TFLite_Detection_PostProcess:2', + output_details[2]['name']) + self.assertTrue(([1, 10] == output_details[2]['shape']).all()) + self.assertEqual('TFLite_Detection_PostProcess:3', + output_details[3]['name']) + self.assertTrue(([1] == output_details[3]['shape']).all()) + + def testTFLiteGraphDefInvalid(self): + # Tests invalid cases for the model that cannot be loaded in TensorFlow. + self._initObjectDetectionArgs() + + # Missing `input_shapes`. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph( + self._graph_def_file, self._input_arrays, self._output_arrays) + self.assertEqual('input_shapes must be defined for this model.', + str(error.exception)) + + # `input_shapes` does not contain the names in `input_arrays`. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph( + self._graph_def_file, + self._input_arrays, + self._output_arrays, + input_shapes={'invalid-value': [1, 19]}) + self.assertEqual( + 'input_shapes must contain a value for each item in input_array.', + str(error.exception)) + class FromSavedModelTest(test_util.TensorFlowTestCase): |