aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 1c94ba605a..3f8ea433ff 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -521,14 +521,21 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
- def testInvalidFile(self):
+ def testInvalidFileNotFound(self):
+ with self.assertRaises(IOError) as error:
+ lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
+ self.assertEqual('File \'invalid_file\' does not exist.',
+ str(error.exception))
+
+ def testInvalidFileBadData(self):
graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
with gfile.Open(graph_def_file, 'wb') as temp_file:
temp_file.write('bad data')
temp_file.flush()
# Attempts to convert the invalid model.
- with self.assertRaises(ValueError) as error:
+ with self.assertRaises(IOError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
self.assertEqual(
@@ -539,7 +546,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def _initObjectDetectionArgs(self):
# Initializes the arguments required for the object detection model.
self._graph_def_file = resource_loader.get_path_to_datafile(
- 'testdata/tflite_graph.pbtxt')
+ 'testdata/tflite_graph.pb')
self._input_arrays = ['normalized_input_image_tensor']
self._output_arrays = [
'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
@@ -586,7 +593,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
output_details[3]['name'])
self.assertTrue(([1] == output_details[3]['shape']).all())
- def testTFLiteGraphDefInvalid(self):
+ def testTFLiteGraphDefMissingShape(self):
# Tests invalid cases for the model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs()
@@ -597,6 +604,10 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertEqual('input_shapes must be defined for this model.',
str(error.exception))
+ def testTFLiteGraphDefInvalidShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
# `input_shapes` does not contain the names in `input_arrays`.
with self.assertRaises(ValueError) as error:
lite.TocoConverter.from_frozen_graph(