diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-06-25 16:09:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-25 16:12:36 -0700 |
commit | 3ce0d77e0c80bf5d2568fdeefd3042b62c96079f (patch) | |
tree | 0d0c279904efb03113f15d6f195286e64a19957b /tensorflow/contrib/lite/toco/g3doc/python_api.md | |
parent | f2460fc21b22b65ca57c7ea996e4e8d003aa3371 (diff) |
Adds tf.keras support to TocoConverter.
PiperOrigin-RevId: 202037381
Diffstat (limited to 'tensorflow/contrib/lite/toco/g3doc/python_api.md')
-rw-r--r-- | tensorflow/contrib/lite/toco/g3doc/python_api.md | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index afa6fd6957..b04d166f89 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -15,6 +15,7 @@ Table of contents: * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) * [Exporting a GraphDef from file](#basic-graphdef-file) * [Exporting a SavedModel](#basic-savedmodel) + * [Exporting a tf.keras File](#basic-keras-file) * [Complex examples](#complex) * [Exporting a quantized GraphDef](#complex-quant) * [TensorFlow Lite Python interpreter](#interpreter) @@ -114,6 +115,51 @@ For more complex SavedModels, the optional parameters that can be passed into `output_arrays`, `tag_set` and `signature_key`. Details of each parameter are available by running `help(tf.contrib.lite.TocoConverter)`. +### Exporting a tf.keras File <a name="basic-keras-file"></a> + +The following example shows how to convert a tf.keras model into a TensorFlow +Lite FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5") +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +The tf.keras file must contain both the model and the weights. A comprehensive +example including model construction can be seen below. + +```python +import numpy as np +import tensorflow as tf + +# Generate tf.keras model. +model = tf.keras.models.Sequential() +model.add(tf.keras.layers.Dense(2, input_shape=(3,))) +model.add(tf.keras.layers.RepeatVector(3)) +model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3))) +model.compile(loss=tf.keras.losses.MSE, + optimizer=tf.keras.optimizers.RMSprop(lr=0.0001), + metrics=[tf.keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + +x = np.random.random((1, 3)) +y = np.random.random((1, 3, 3)) +model.train_on_batch(x, y) +model.predict(x) + +# Save tf.keras model in HDF5 format. +keras_file = "keras_model.h5" +tf.keras.models.save_model(model, keras_file) + +# Convert to TensorFlow Lite model. +converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + ## Complex examples <a name="complex"></a> For models where the default value of the attributes is not sufficient, the |