aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/lite.py')
-rw-r--r--tensorflow/contrib/lite/python/lite.py65
1 files changed, 31 insertions, 34 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index a4229f91f5..2f9b9d469a 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,24 +40,23 @@ from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
-from tensorflow.contrib.lite.python.convert import tensor_name
+from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
-from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model
-from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names
-from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes
+from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
+from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
+from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes as _set_tensor_shapes
from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
-from tensorflow.python.framework import graph_util as tf_graph_util
-from tensorflow.python.framework.importer import import_graph_def
-from tensorflow.python.ops.variables import global_variables_initializer
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import tag_constants
-# from tensorflow.python.util.all_util import remove_undocumented
+from tensorflow.python.framework import graph_util as _tf_graph_util
+from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
+from tensorflow.python.saved_model import signature_constants as _signature_constants
+from tensorflow.python.saved_model import tag_constants as _tag_constants
class TocoConverter(object):
@@ -132,7 +131,7 @@ class TocoConverter(object):
Args:
- graph_def: TensorFlow GraphDef.
+ graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
@@ -178,7 +177,7 @@ class TocoConverter(object):
"""Creates a TocoConverter class from a file containing a frozen GraphDef.
Args:
- graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ graph_def_file: Full filepath of file containing frozen GraphDef.
input_arrays: List of input tensors to freeze graph with.
output_arrays: List of output tensors to freeze graph with.
input_shapes: Dict of strings representing input tensor names to list of
@@ -196,7 +195,7 @@ class TocoConverter(object):
input_arrays or output_arrays contains an invalid tensor name.
"""
with _session.Session() as sess:
- sess.run(global_variables_initializer())
+ sess.run(_global_variables_initializer())
# Read GraphDef from file.
graph_def = _graph_pb2.GraphDef()
@@ -218,12 +217,12 @@ class TocoConverter(object):
raise ValueError(
"Unable to parse input file '{}'.".format(graph_def_file))
sess.graph.as_default()
- import_graph_def(graph_def, name="")
+ _import_graph_def(graph_def, name="")
# Get input and output tensors.
- input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
- set_tensor_shapes(input_tensors, input_shapes)
+ input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
# Check if graph is frozen.
if not _is_frozen_graph(sess):
@@ -261,12 +260,12 @@ class TocoConverter(object):
TocoConverter class.
"""
if tag_set is None:
- tag_set = set([tag_constants.SERVING])
+ tag_set = set([_tag_constants.SERVING])
if signature_key is None:
- signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
- output_arrays, tag_set, signature_key)
+ result = _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key)
return cls(
graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
@@ -299,15 +298,15 @@ class TocoConverter(object):
# Get input and output tensors.
if input_arrays:
- input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
else:
input_tensors = keras_model.inputs
if output_arrays:
- output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ output_tensors = _get_tensors_from_tensor_names(sess.graph, output_arrays)
else:
output_tensors = keras_model.outputs
- set_tensor_shapes(input_tensors, input_shapes)
+ _set_tensor_shapes(input_tensors, input_shapes)
graph_def = _freeze_graph(sess, output_tensors)
return cls(graph_def, input_tensors, output_tensors)
@@ -328,12 +327,12 @@ class TocoConverter(object):
for tensor in self._input_tensors:
if not tensor.get_shape():
raise ValueError("Provide an input shape for input array '{0}'.".format(
- tensor_name(tensor)))
+ _tensor_name(tensor)))
shape = tensor.get_shape().as_list()
if None in shape[1:]:
raise ValueError(
"None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(tensor_name(tensor), shape))
+ "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
elif shape[0] is None:
self._set_batch_size(batch_size=1)
@@ -343,7 +342,7 @@ class TocoConverter(object):
quantized_stats = []
invalid_stats = []
for tensor in self._input_tensors:
- name = tensor_name(tensor)
+ name = _tensor_name(tensor)
if name in self.quantized_input_stats:
quantized_stats.append(self.quantized_input_stats[name])
else:
@@ -381,7 +380,7 @@ class TocoConverter(object):
Returns:
List of strings.
"""
- return [tensor_name(tensor) for tensor in self._input_tensors]
+ return [_tensor_name(tensor) for tensor in self._input_tensors]
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -428,11 +427,9 @@ def _freeze_graph(sess, output_tensors):
Frozen GraphDef.
"""
if not _is_frozen_graph(sess):
- sess.run(global_variables_initializer())
- output_arrays = [tensor_name(tensor) for tensor in output_tensors]
- return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def,
- output_arrays)
+ sess.run(_global_variables_initializer())
+ output_arrays = [_tensor_name(tensor) for tensor in output_tensors]
+ return _tf_graph_util.convert_variables_to_constants(
+ sess, sess.graph_def, output_arrays)
else:
return sess.graph_def
-
-# remove_undocumented(__name__)