aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite_test.py
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-06-25 16:09:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 16:12:36 -0700
commit3ce0d77e0c80bf5d2568fdeefd3042b62c96079f (patch)
tree0d0c279904efb03113f15d6f195286e64a19957b /tensorflow/contrib/lite/python/lite_test.py
parentf2460fc21b22b65ca57c7ea996e4e8d003aa3371 (diff)
Adds tf.keras support to TocoConverter.
PiperOrigin-RevId: 202037381
Diffstat (limited to 'tensorflow/contrib/lite/python/lite_test.py')
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py276
1 files changed, 276 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index a9475de474..ca2af5aaed 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import os
+import tempfile
import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -618,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
+class FromKerasFile(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ keras.backend.clear_session()
+
+ def _getSequentialModel(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testSequentialModel(self):
+ """Test a Sequential tf.keras model with default inputs."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testSequentialModelInputArray(self):
+ """Test a Sequential tf.keras model testing input arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid input array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['invalid-input'])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ # Valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['dense_input'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testSequentialModelInputShape(self):
+ """Test a Sequential tf.keras model testing input shapes argument."""
+ keras_file = self._getSequentialModel()
+
+ # Passing in shape of invalid input array has no impact as long as all input
+ # arrays have a shape.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'invalid-input': [2, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Passing in shape of valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'dense_input': [2, 3]})
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ # Check input shape from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertTrue(([2, 3] == input_details[0]['shape']).all())
+
+ def testSequentialModelOutputArray(self):
+ """Test a Sequential tf.keras model testing output arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid output array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['invalid-output'])
+ self.assertEqual("Invalid tensors 'invalid-output' were found.",
+ str(error.exception))
+
+ # Valid output array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['time_distributed/Reshape_1'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testFunctionalModel(self):
+ """Test a Functional tf.keras model with default inputs."""
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFunctionalModelMultipleInputs(self):
+ """Test a Functional tf.keras model with multiple inputs and outputs."""
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('input_a', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('input_b', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(2, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('dropout/Identity', output_details[1]['name'])
+ self.assertEqual(np.float32, output_details[1]['dtype'])
+ self.assertTrue(([1, 4] == output_details[1]['shape']).all())
+ self.assertEqual((0., 0.), output_details[1]['quantization'])
+
+ def testFunctionalSequentialModel(self):
+ """Test a Functional tf.keras model containing a Sequential model."""
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+
if __name__ == '__main__':
test.main()