From 2c8bc1587e9480a44c10146d0e9472c1d6f9c7d7 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 5 Sep 2018 16:24:29 -0700 Subject: Fix lite_test.py. PiperOrigin-RevId: 211719399 --- tensorflow/contrib/lite/python/BUILD | 2 +- tensorflow/contrib/lite/python/lite.py | 15 +++++++++++---- tensorflow/contrib/lite/python/lite_test.py | 19 +++++++++++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) (limited to 'tensorflow/contrib/lite/python') diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 6e30251eff..57e1290e07 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -70,7 +70,7 @@ py_library( py_test( name = "lite_test", srcs = ["lite_test.py"], - data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"], + data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"], srcs_version = "PY2AND3", tags = [ "no_oss", diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 2de97fec86..44dfb97b84 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -58,6 +58,7 @@ from tensorflow.python.framework import graph_util as _tf_graph_util from tensorflow.python.framework import ops as _ops from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError from tensorflow.python.framework.importer import import_graph_def as _import_graph_def +from tensorflow.python.lib.io import file_io as _file_io from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants @@ -225,8 +226,10 @@ class TocoConverter(object): TocoConverter class. Raises: - ValueError: + IOError: + File not found. Unable to parse input file. + ValueError: The graph is not frozen. input_arrays or output_arrays contains an invalid tensor name. input_shapes is not correctly defined when required @@ -234,10 +237,13 @@ class TocoConverter(object): with _ops.Graph().as_default(): with _session.Session() as sess: # Read GraphDef from file. - graph_def = _graph_pb2.GraphDef() - with open(graph_def_file, "rb") as f: + if not _file_io.file_exists(graph_def_file): + raise IOError("File '{0}' does not exist.".format(graph_def_file)) + with _file_io.FileIO(graph_def_file, "rb") as f: file_content = f.read() + try: + graph_def = _graph_pb2.GraphDef() graph_def.ParseFromString(file_content) except (_text_format.ParseError, DecodeError): try: @@ -248,9 +254,10 @@ class TocoConverter(object): file_content = file_content.decode("utf-8") else: file_content = file_content.encode("utf-8") + graph_def = _graph_pb2.GraphDef() _text_format.Merge(file_content, graph_def) except (_text_format.ParseError, DecodeError): - raise ValueError( + raise IOError( "Unable to parse input file '{}'.".format(graph_def_file)) # Handles models with custom TFLite ops that cannot be resolved in diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 1c94ba605a..3f8ea433ff 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -521,14 +521,21 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) - def testInvalidFile(self): + def testInvalidFileNotFound(self): + with self.assertRaises(IOError) as error: + lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'], + ['add']) + self.assertEqual('File \'invalid_file\' does not exist.', + str(error.exception)) + + def testInvalidFileBadData(self): graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') with gfile.Open(graph_def_file, 'wb') as temp_file: temp_file.write('bad data') temp_file.flush() # Attempts to convert the invalid model. - with self.assertRaises(ValueError) as error: + with self.assertRaises(IOError) as error: lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], ['add']) self.assertEqual( @@ -539,7 +546,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): def _initObjectDetectionArgs(self): # Initializes the arguments required for the object detection model. self._graph_def_file = resource_loader.get_path_to_datafile( - 'testdata/tflite_graph.pbtxt') + 'testdata/tflite_graph.pb') self._input_arrays = ['normalized_input_image_tensor'] self._output_arrays = [ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', @@ -586,7 +593,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): output_details[3]['name']) self.assertTrue(([1] == output_details[3]['shape']).all()) - def testTFLiteGraphDefInvalid(self): + def testTFLiteGraphDefMissingShape(self): # Tests invalid cases for the model that cannot be loaded in TensorFlow. self._initObjectDetectionArgs() @@ -597,6 +604,10 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): self.assertEqual('input_shapes must be defined for this model.', str(error.exception)) + def testTFLiteGraphDefInvalidShape(self): + # Tests invalid cases for the model that cannot be loaded in TensorFlow. + self._initObjectDetectionArgs() + # `input_shapes` does not contain the names in `input_arrays`. with self.assertRaises(ValueError) as error: lite.TocoConverter.from_frozen_graph( -- cgit v1.2.3