aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/g3doc/python_api.md
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-06-25 16:09:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 16:12:36 -0700
commit3ce0d77e0c80bf5d2568fdeefd3042b62c96079f (patch)
tree0d0c279904efb03113f15d6f195286e64a19957b /tensorflow/contrib/lite/toco/g3doc/python_api.md
parentf2460fc21b22b65ca57c7ea996e4e8d003aa3371 (diff)
Adds tf.keras support to TocoConverter.
PiperOrigin-RevId: 202037381
Diffstat (limited to 'tensorflow/contrib/lite/toco/g3doc/python_api.md')
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md46
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index afa6fd6957..b04d166f89 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -15,6 +15,7 @@ Table of contents:
* [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
* [Exporting a GraphDef from file](#basic-graphdef-file)
* [Exporting a SavedModel](#basic-savedmodel)
+ * [Exporting a tf.keras File](#basic-keras-file)
* [Complex examples](#complex)
* [Exporting a quantized GraphDef](#complex-quant)
* [TensorFlow Lite Python interpreter](#interpreter)
@@ -114,6 +115,51 @@ For more complex SavedModels, the optional parameters that can be passed into
`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are
available by running `help(tf.contrib.lite.TocoConverter)`.
+### Exporting a tf.keras File <a name="basic-keras-file"></a>
+
+The following example shows how to convert a tf.keras model into a TensorFlow
+Lite FlatBuffer.
+
+```python
+import tensorflow as tf
+
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5")
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
+The tf.keras file must contain both the model and the weights. A comprehensive
+example including model construction can be seen below.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Generate tf.keras model.
+model = tf.keras.models.Sequential()
+model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
+model.add(tf.keras.layers.RepeatVector(3))
+model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
+model.compile(loss=tf.keras.losses.MSE,
+ optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[tf.keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+
+x = np.random.random((1, 3))
+y = np.random.random((1, 3, 3))
+model.train_on_batch(x, y)
+model.predict(x)
+
+# Save tf.keras model in HDF5 format.
+keras_file = "keras_model.h5"
+tf.keras.models.save_model(model, keras_file)
+
+# Convert to TensorFlow Lite model.
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
## Complex examples <a name="complex"></a>
For models where the default value of the attributes is not sufficient, the