aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-09-05 16:24:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 16:35:01 -0700
commit2c8bc1587e9480a44c10146d0e9472c1d6f9c7d7 (patch)
tree64d5c7c08f0663bee167886111abfe0e06f7e664
parentbded7fb63e932c7a7139a32d0e958479d90dbc1d (diff)
Fix lite_test.py.
PiperOrigin-RevId: 211719399
-rw-r--r--tensorflow/contrib/lite/python/BUILD2
-rw-r--r--tensorflow/contrib/lite/python/lite.py15
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py19
3 files changed, 27 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 6e30251eff..57e1290e07 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
- data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"],
+ data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2de97fec86..44dfb97b84 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -58,6 +58,7 @@ from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -225,8 +226,10 @@ class TocoConverter(object):
TocoConverter class.
Raises:
- ValueError:
+ IOError:
+ File not found.
Unable to parse input file.
+ ValueError:
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
input_shapes is not correctly defined when required
@@ -234,10 +237,13 @@ class TocoConverter(object):
with _ops.Graph().as_default():
with _session.Session() as sess:
# Read GraphDef from file.
- graph_def = _graph_pb2.GraphDef()
- with open(graph_def_file, "rb") as f:
+ if not _file_io.file_exists(graph_def_file):
+ raise IOError("File '{0}' does not exist.".format(graph_def_file))
+ with _file_io.FileIO(graph_def_file, "rb") as f:
file_content = f.read()
+
try:
+ graph_def = _graph_pb2.GraphDef()
graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
try:
@@ -248,9 +254,10 @@ class TocoConverter(object):
file_content = file_content.decode("utf-8")
else:
file_content = file_content.encode("utf-8")
+ graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError):
- raise ValueError(
+ raise IOError(
"Unable to parse input file '{}'.".format(graph_def_file))
# Handles models with custom TFLite ops that cannot be resolved in
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 1c94ba605a..3f8ea433ff 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -521,14 +521,21 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
- def testInvalidFile(self):
+ def testInvalidFileNotFound(self):
+ with self.assertRaises(IOError) as error:
+ lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
+ self.assertEqual('File \'invalid_file\' does not exist.',
+ str(error.exception))
+
+ def testInvalidFileBadData(self):
graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
with gfile.Open(graph_def_file, 'wb') as temp_file:
temp_file.write('bad data')
temp_file.flush()
# Attempts to convert the invalid model.
- with self.assertRaises(ValueError) as error:
+ with self.assertRaises(IOError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
self.assertEqual(
@@ -539,7 +546,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def _initObjectDetectionArgs(self):
# Initializes the arguments required for the object detection model.
self._graph_def_file = resource_loader.get_path_to_datafile(
- 'testdata/tflite_graph.pbtxt')
+ 'testdata/tflite_graph.pb')
self._input_arrays = ['normalized_input_image_tensor']
self._output_arrays = [
'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
@@ -586,7 +593,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
output_details[3]['name'])
self.assertTrue(([1] == output_details[3]['shape']).all())
- def testTFLiteGraphDefInvalid(self):
+ def testTFLiteGraphDefMissingShape(self):
# Tests invalid cases for the model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs()
@@ -597,6 +604,10 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertEqual('input_shapes must be defined for this model.',
str(error.exception))
+ def testTFLiteGraphDefInvalidShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
# `input_shapes` does not contain the names in `input_arrays`.
with self.assertRaises(ValueError) as error:
lite.TocoConverter.from_frozen_graph(