diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/export.h')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/export.h | 54 |
1 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 58ea5c725c..b070a38768 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -23,22 +23,55 @@ namespace toco { namespace tflite { +// The parameters for exporting a TFLite model. +struct ExportParams { + bool allow_custom_ops = false; + bool allow_eager_ops = false; + bool quantize_weights = false; +}; + // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the // result in the given string. -void Export(const Model& model, bool allow_custom_ops, - string* output_file_contents); - -// This if backward-compatibility. -// TODO(ycling): Remove the deprecated entry functions. -inline void Export(const Model& model, string* output_file_contents) { - Export(model, true, output_file_contents); -} +void Export(const Model& model, string* output_file_contents, + const ExportParams& params); // Export API with custom TFLite operator mapping. void Export( - const Model& model, bool allow_custom_ops, string* output_file_contents, + const Model& model, string* output_file_contents, + const ExportParams& params, const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export(const Model& model, bool allow_custom_ops, + bool quantize_weights, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params); +} + +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export( + const Model& model, bool allow_custom_ops, bool quantize_weights, + string* output_file_contents, + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params, ops_by_type); +} + +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export(const Model& model, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = true; + Export(model, output_file_contents, params); + Export(model, true, false, output_file_contents); +} + namespace details { // A maps from tensor name to its final position in the TF Lite buffer. @@ -87,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type); + const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type, + bool allow_eager_ops); } // namespace details } // namespace tflite |