diff options
author | 2018-06-26 10:56:16 -0700 | |
---|---|---|
committer | 2018-06-26 10:56:16 -0700 | |
commit | 7f1056bcc9af72f6ed68939423362e390ce6ad8b (patch) | |
tree | cc434c644a508ac442f79d4463f72c929a017444 /tensorflow/contrib/lite/python | |
parent | 343b373e3386f11a16a5216574492ca56bfd7050 (diff) | |
parent | f2813bf6e4f7f415f012307a03fd5b9fb5822d28 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 45 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 276 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 7 |
3 files changed, 327 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 69a2f638af..a4229f91f5 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -50,6 +50,7 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python import keras as _keras from tensorflow.python.client import session as _session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework.importer import import_graph_def @@ -269,6 +270,48 @@ class TocoConverter(object): return cls( graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + @classmethod + def from_keras_model_file(cls, + model_file, + input_arrays=None, + input_shapes=None, + output_arrays=None): + """Creates a TocoConverter class from a tf.keras model file. + + Args: + model_file: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + + Returns: + TocoConverter class. + """ + _keras.backend.clear_session() + _keras.backend.set_learning_phase(False) + keras_model = _keras.models.load_model(model_file) + sess = _keras.backend.get_session() + + # Get input and output tensors. + if input_arrays: + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + else: + input_tensors = keras_model.inputs + + if output_arrays: + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + else: + output_tensors = keras_model.outputs + set_tensor_shapes(input_tensors, input_shapes) + + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + def convert(self): """Converts a TensorFlow GraphDef based on instance variables. @@ -366,7 +409,7 @@ def _is_frozen_graph(sess): Bool. """ for op in sess.graph.get_operations(): - if op.type.startswith("Variable"): + if op.type.startswith("Variable") or op.type.endswith("VariableOp"): return False return True diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index a9475de474..ca2af5aaed 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import tempfile import numpy as np from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase): self.assertTrue(tflite_model) +class FromKerasFile(test_util.TensorFlowTestCase): + + def setUp(self): + keras.backend.clear_session() + + def _getSequentialModel(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + try: + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + finally: + os.close(fd) + return keras_file + + def testSequentialModel(self): + """Test a Sequential tf.keras model with default inputs.""" + keras_file = self._getSequentialModel() + + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSequentialModelInputArray(self): + """Test a Sequential tf.keras model testing input arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid input array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['invalid-input']) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_arrays=['dense_input']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testSequentialModelInputShape(self): + """Test a Sequential tf.keras model testing input shapes argument.""" + keras_file = self._getSequentialModel() + + # Passing in shape of invalid input array has no impact as long as all input + # arrays have a shape. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'invalid-input': [2, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Passing in shape of valid input array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, input_shapes={'dense_input': [2, 3]}) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + # Check input shape from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertTrue(([2, 3] == input_details[0]['shape']).all()) + + def testSequentialModelOutputArray(self): + """Test a Sequential tf.keras model testing output arrays argument.""" + keras_file = self._getSequentialModel() + + # Invalid output array raises error. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['invalid-output']) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) + + # Valid output array. + converter = lite.TocoConverter.from_keras_model_file( + keras_file, output_arrays=['time_distributed/Reshape_1']) + tflite_model = converter.convert() + os.remove(keras_file) + self.assertTrue(tflite_model) + + def testFunctionalModel(self): + """Test a Functional tf.keras model with default inputs.""" + inputs = keras.layers.Input(shape=(3,), name='input') + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFunctionalModelMultipleInputs(self): + """Test a Functional tf.keras model with multiple inputs and outputs.""" + a = keras.layers.Input(shape=(3,), name='input_a') + b = keras.layers.Input(shape=(3,), name='input_b') + dense = keras.layers.Dense(4, name='dense') + c = dense(a) + d = dense(b) + e = keras.layers.Dropout(0.5, name='dropout')(c) + + model = keras.models.Model([a, b], [d, e]) + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.mae], + loss_weights=[1., 0.5]) + + input_a_np = np.random.random((10, 3)) + input_b_np = np.random.random((10, 3)) + output_d_np = np.random.random((10, 4)) + output_e_np = np.random.random((10, 4)) + model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) + + model.predict([input_a_np, input_b_np], batch_size=5) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('input_a', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('input_b', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(2, len(output_details)) + self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + self.assertEqual('dropout/Identity', output_details[1]['name']) + self.assertEqual(np.float32, output_details[1]['dtype']) + self.assertTrue(([1, 4] == output_details[1]['shape']).all()) + self.assertEqual((0., 0.), output_details[1]['quantization']) + + def testFunctionalSequentialModel(self): + """Test a Functional tf.keras model containing a Sequential model.""" + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model = keras.models.Model(model.input, model.output) + + model.compile( + loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + model.predict(x) + + model.predict(x) + fd, keras_file = tempfile.mkstemp('.h5') + keras.models.save_model(model, keras_file) + + # Convert to TFLite model. + converter = lite.TocoConverter.from_keras_model_file(keras_file) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + os.close(fd) + os.remove(keras_file) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('dense_input', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index d18a29834b..249b940f92 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -74,6 +74,9 @@ def _get_toco_converter(flags): converter_kwargs["saved_model_dir"] = flags.saved_model_dir converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) converter_kwargs["signature_key"] = flags.saved_model_signature_key + elif flags.keras_model_file: + converter_fn = lite.TocoConverter.from_keras_model_file + converter_kwargs["model_file"] = flags.keras_model_file return converter_fn(**converter_kwargs) @@ -227,6 +230,10 @@ def run_main(_): "--saved_model_dir", type=str, help="Full filepath of directory containing the SavedModel.") + input_file_group.add_argument( + "--keras_model_file", + type=str, + help="Full filepath of HDF5 file containing tf.Keras model.") # Model format flags. parser.add_argument( |