aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-08-13 16:11:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-13 16:21:09 -0700
commit859f68f6c444619de24db9bf43f2a0978997797f (patch)
treeb2df41964744eec450d5b4f749eb9aa34e573d1a /tensorflow/contrib/lite/python
parent83f1458ec1c19b3d46676ab543dff4ec401a0dd0 (diff)
Create a new graph for loading the frozen graph in TocoConverter.
PiperOrigin-RevId: 208560644
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/lite.py66
1 files changed, 34 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 52ef43d71f..5ec52035ad 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -53,6 +53,7 @@ from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
+from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -193,40 +194,41 @@ class TocoConverter(object):
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
"""
- with _session.Session() as sess:
- # Read GraphDef from file.
- graph_def = _graph_pb2.GraphDef()
- with open(graph_def_file, "rb") as f:
- file_content = f.read()
- try:
- graph_def.ParseFromString(file_content)
- except (_text_format.ParseError, DecodeError):
+ with _ops.Graph().as_default():
+ with _session.Session() as sess:
+ # Read GraphDef from file.
+ graph_def = _graph_pb2.GraphDef()
+ with open(graph_def_file, "rb") as f:
+ file_content = f.read()
try:
- print("Ignore 'tcmalloc: large alloc' warnings.")
-
- if not isinstance(file_content, str):
- if PY3:
- file_content = file_content.decode('utf-8')
- else:
- file_content = file_content.encode('utf-8')
- _text_format.Merge(file_content, graph_def)
+ graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
- raise ValueError(
- "Unable to parse input file '{}'.".format(graph_def_file))
- sess.graph.as_default()
- _import_graph_def(graph_def, name="")
-
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
-
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ try:
+ print("Ignore 'tcmalloc: large alloc' warnings.")
+
+ if not isinstance(file_content, str):
+ if PY3:
+ file_content = file_content.decode("utf-8")
+ else:
+ file_content = file_content.encode("utf-8")
+ _text_format.Merge(file_content, graph_def)
+ except (_text_format.ParseError, DecodeError):
+ raise ValueError(
+ "Unable to parse input file '{}'.".format(graph_def_file))
+ _import_graph_def(graph_def, name="")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(sess.graph,
+ output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Create TocoConverter class.
+ return cls(sess.graph_def, input_tensors, output_tensors)
@classmethod
def from_saved_model(cls,