aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-06-26 10:56:16 -0700
committerGravatar Michael Case <mikecase@google.com>2018-06-26 10:56:16 -0700
commit7f1056bcc9af72f6ed68939423362e390ce6ad8b (patch)
treecc434c644a508ac442f79d4463f72c929a017444 /tensorflow/contrib/lite/python
parent343b373e3386f11a16a5216574492ca56bfd7050 (diff)
parentf2813bf6e4f7f415f012307a03fd5b9fb5822d28 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/lite.py45
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py276
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py7
3 files changed, 327 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 69a2f638af..a4229f91f5 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -50,6 +50,7 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework.importer import import_graph_def
@@ -269,6 +270,48 @@ class TocoConverter(object):
return cls(
graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
+ @classmethod
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file.
+
+ Args:
+ model_file: Full filepath of HDF5 file containing the tf.keras model.
+ input_arrays: List of input tensors to freeze graph with. Uses input
+ arrays from SignatureDef when none are provided. (default None)
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+
+ Returns:
+ TocoConverter class.
+ """
+ _keras.backend.clear_session()
+ _keras.backend.set_learning_phase(False)
+ keras_model = _keras.models.load_model(model_file)
+ sess = _keras.backend.get_session()
+
+ # Get input and output tensors.
+ if input_arrays:
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ else:
+ input_tensors = keras_model.inputs
+
+ if output_arrays:
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ else:
+ output_tensors = keras_model.outputs
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
+
def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -366,7 +409,7 @@ def _is_frozen_graph(sess):
Bool.
"""
for op in sess.graph.get_operations():
- if op.type.startswith("Variable"):
+ if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
return False
return True
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index a9475de474..ca2af5aaed 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import os
+import tempfile
import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
+class FromKerasFile(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ keras.backend.clear_session()
+
+ def _getSequentialModel(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testSequentialModel(self):
+ """Test a Sequential tf.keras model with default inputs."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testSequentialModelInputArray(self):
+ """Test a Sequential tf.keras model testing input arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid input array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['invalid-input'])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ # Valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['dense_input'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testSequentialModelInputShape(self):
+ """Test a Sequential tf.keras model testing input shapes argument."""
+ keras_file = self._getSequentialModel()
+
+ # Passing in shape of invalid input array has no impact as long as all input
+ # arrays have a shape.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'invalid-input': [2, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Passing in shape of valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'dense_input': [2, 3]})
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ # Check input shape from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertTrue(([2, 3] == input_details[0]['shape']).all())
+
+ def testSequentialModelOutputArray(self):
+ """Test a Sequential tf.keras model testing output arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid output array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['invalid-output'])
+ self.assertEqual("Invalid tensors 'invalid-output' were found.",
+ str(error.exception))
+
+ # Valid output array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['time_distributed/Reshape_1'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testFunctionalModel(self):
+ """Test a Functional tf.keras model with default inputs."""
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFunctionalModelMultipleInputs(self):
+ """Test a Functional tf.keras model with multiple inputs and outputs."""
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('input_a', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('input_b', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(2, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('dropout/Identity', output_details[1]['name'])
+ self.assertEqual(np.float32, output_details[1]['dtype'])
+ self.assertTrue(([1, 4] == output_details[1]['shape']).all())
+ self.assertEqual((0., 0.), output_details[1]['quantization'])
+
+ def testFunctionalSequentialModel(self):
+ """Test a Functional tf.keras model containing a Sequential model."""
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index d18a29834b..249b940f92 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -74,6 +74,9 @@ def _get_toco_converter(flags):
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
+ elif flags.keras_model_file:
+ converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_kwargs["model_file"] = flags.keras_model_file
return converter_fn(**converter_kwargs)
@@ -227,6 +230,10 @@ def run_main(_):
"--saved_model_dir",
type=str,
help="Full filepath of directory containing the SavedModel.")
+ input_file_group.add_argument(
+ "--keras_model_file",
+ type=str,
+ help="Full filepath of HDF5 file containing tf.Keras model.")
# Model format flags.
parser.add_argument(