diff options
Diffstat (limited to 'tensorflow/contrib/lite/python/lite.py')
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 171 |
1 files changed, 126 insertions, 45 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 2313bfa3b6..a4c9a2381c 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -42,6 +42,7 @@ from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model @@ -55,6 +56,7 @@ from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as _tf_graph_util from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError from tensorflow.python.framework.importer import import_graph_def as _import_graph_def from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants @@ -133,7 +135,12 @@ class TocoConverter(object): ``` """ - def __init__(self, graph_def, input_tensors, output_tensors): + def __init__(self, + graph_def, + input_tensors, + output_tensors, + input_arrays_with_shape=None, + output_arrays=None): """Constructor for TocoConverter. Args: @@ -142,6 +149,17 @@ class TocoConverter(object): input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). + input_arrays_with_shape: Tuple of strings representing input tensor names + and list of integers representing input shapes + (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded + into TensorFlow and when `input_tensors` and `output_tensors` are None. + (default None) + output_arrays: List of output tensors to freeze graph with. Use only when + graph cannot be loaded into TensorFlow and when `input_tensors` and + `output_tensors` are None. (default None) + + Raises: + ValueError: Invalid arguments. """ self._graph_def = graph_def self._input_tensors = input_tensors @@ -159,6 +177,15 @@ class TocoConverter(object): self.dump_graphviz_dir = None self.dump_graphviz_video = False + # Attributes are used by models that cannot be loaded into TensorFlow. + if not self._has_valid_tensors(): + if not input_arrays_with_shape or not output_arrays: + raise ValueError( + "If input_tensors and output_tensors are None, both " + "input_arrays_with_shape and output_arrays must be defined.") + self._input_arrays_with_shape = input_arrays_with_shape + self._output_arrays = output_arrays + @classmethod def from_session(cls, sess, input_tensors, output_tensors): """Creates a TocoConverter class from a TensorFlow Session. @@ -200,6 +227,7 @@ class TocoConverter(object): Unable to parse input file. The graph is not frozen. input_arrays or output_arrays contains an invalid tensor name. + input_shapes is not correctly defined when required """ with _ops.Graph().as_default(): with _session.Session() as sess: @@ -222,20 +250,44 @@ class TocoConverter(object): except (_text_format.ParseError, DecodeError): raise ValueError( "Unable to parse input file '{}'.".format(graph_def_file)) - _import_graph_def(graph_def, name="") - - # Get input and output tensors. - input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays) - output_tensors = _get_tensors_from_tensor_names(sess.graph, - output_arrays) - _set_tensor_shapes(input_tensors, input_shapes) - - # Check if graph is frozen. - if not _is_frozen_graph(sess): - raise ValueError("Please freeze the graph using freeze_graph.py.") - # Create TocoConverter class. - return cls(sess.graph_def, input_tensors, output_tensors) + # Handles models with custom TFLite ops that cannot be resolved in + # TensorFlow. + load_model_in_session = True + try: + _import_graph_def(graph_def, name="") + except _NotFoundError: + load_model_in_session = False + + if load_model_in_session: + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") + + # Get input and output tensors. + input_tensors = _get_tensors_from_tensor_names( + sess.graph, input_arrays) + output_tensors = _get_tensors_from_tensor_names( + sess.graph, output_arrays) + _set_tensor_shapes(input_tensors, input_shapes) + + return cls(sess.graph_def, input_tensors, output_tensors) + else: + if not input_shapes: + raise ValueError("input_shapes must be defined for this model.") + if set(input_arrays) != set(input_shapes.keys()): + raise ValueError("input_shapes must contain a value for each item " + "in input_array.") + + input_arrays_with_shape = [ + (name, input_shapes[name]) for name in input_arrays + ] + return cls( + graph_def, + input_tensors=None, + output_tensors=None, + input_arrays_with_shape=input_arrays_with_shape, + output_arrays=output_arrays) @classmethod def from_saved_model(cls, @@ -330,25 +382,25 @@ class TocoConverter(object): None value for dimension in input_tensor. """ # Checks dimensions in input tensor. - for tensor in self._input_tensors: - if not tensor.get_shape(): - raise ValueError("Provide an input shape for input array '{0}'.".format( - _tensor_name(tensor))) - shape = tensor.get_shape().as_list() - if None in shape[1:]: - raise ValueError( - "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(_tensor_name(tensor), shape)) - elif shape[0] is None: - self._set_batch_size(batch_size=1) + if self._has_valid_tensors(): + for tensor in self._input_tensors: + if not tensor.get_shape(): + raise ValueError("Provide an input shape for input array " + "'{0}'.".format(_tensor_name(tensor))) + shape = tensor.get_shape().as_list() + if None in shape[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(_tensor_name(tensor), shape)) + elif shape[0] is None: + self._set_batch_size(batch_size=1) # Get quantization stats. Ensures there is one stat per name if the stats # are specified. if self.quantized_input_stats: quantized_stats = [] invalid_stats = [] - for tensor in self._input_tensors: - name = _tensor_name(tensor) + for name in self.get_input_arrays(): if name in self.quantized_input_stats: quantized_stats.append(self.quantized_input_stats[name]) else: @@ -360,24 +412,35 @@ class TocoConverter(object): else: quantized_stats = None + converter_kwargs = { + "inference_type": self.inference_type, + "inference_input_type": self.inference_input_type, + "input_format": constants.TENSORFLOW_GRAPHDEF, + "output_format": self.output_format, + "quantized_input_stats": quantized_stats, + "default_ranges_stats": self.default_ranges_stats, + "drop_control_dependency": self.drop_control_dependency, + "reorder_across_fake_quant": self.reorder_across_fake_quant, + "change_concat_input_ranges": self.change_concat_input_ranges, + "allow_custom_ops": self.allow_custom_ops, + "quantize_weights": self.quantize_weights, + "dump_graphviz_dir": self.dump_graphviz_dir, + "dump_graphviz_video": self.dump_graphviz_video + } + # Converts model. - result = _toco_convert_impl( - input_data=self._graph_def, - input_tensors=self._input_tensors, - output_tensors=self._output_tensors, - inference_type=self.inference_type, - inference_input_type=self.inference_input_type, - input_format=constants.TENSORFLOW_GRAPHDEF, - output_format=self.output_format, - quantized_input_stats=quantized_stats, - default_ranges_stats=self.default_ranges_stats, - drop_control_dependency=self.drop_control_dependency, - reorder_across_fake_quant=self.reorder_across_fake_quant, - change_concat_input_ranges=self.change_concat_input_ranges, - allow_custom_ops=self.allow_custom_ops, - quantize_weights=self.quantize_weights, - dump_graphviz_dir=self.dump_graphviz_dir, - dump_graphviz_video=self.dump_graphviz_video) + if self._has_valid_tensors(): + result = _toco_convert_impl( + input_data=self._graph_def, + input_tensors=self._input_tensors, + output_tensors=self._output_tensors, + **converter_kwargs) + else: + result = _toco_convert_graph_def( + input_data=self._graph_def, + input_arrays_with_shape=self._input_arrays_with_shape, + output_arrays=self._output_arrays, + **converter_kwargs) return result def get_input_arrays(self): @@ -386,7 +449,18 @@ class TocoConverter(object): Returns: List of strings. """ - return [_tensor_name(tensor) for tensor in self._input_tensors] + if self._has_valid_tensors(): + return [_tensor_name(tensor) for tensor in self._input_tensors] + else: + return [name for name, _ in self._input_arrays_with_shape] + + def _has_valid_tensors(self): + """Checks if the input and output tensors have been initialized. + + Returns: + Bool. + """ + return self._input_tensors and self._output_tensors def _set_batch_size(self, batch_size): """Sets the first dimension of the input tensor to `batch_size`. @@ -394,7 +468,14 @@ class TocoConverter(object): Args: batch_size: Batch size for the model. Replaces the first dimension of an input size array if undefined. (default 1) + + Raises: + ValueError: input_tensor is not defined. """ + if not self._has_valid_tensors(): + raise ValueError("The batch size cannot be set for this model. Please " + "use input_shapes parameter.") + for tensor in self._input_tensors: shape = tensor.get_shape().as_list() shape[0] = batch_size |