From 4a6ab2cb8c2f33ffb6b64d61bd09f006e75982c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 20 Mar 2018 13:38:09 -0700 Subject: Build tflite interpreter from buffer in python interface PiperOrigin-RevId: 189800400 --- tensorflow/contrib/lite/python/interpreter.py | 26 ++++++++--- tensorflow/contrib/lite/python/interpreter_test.py | 53 ++++++++++++---------- .../interpreter_wrapper/interpreter_wrapper.cc | 9 +++- .../interpreter_wrapper/interpreter_wrapper.h | 6 ++- 4 files changed, 62 insertions(+), 32 deletions(-) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 5b5a7c3199..accdd04671 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -23,19 +23,33 @@ from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_i class Interpreter(object): """Interpreter inferace for TF-Lite Models.""" - def __init__(self, model_path): + def __init__(self, model_path=None, model_content=None): """Constructor. Args: model_path: Path to TF-Lite Flatbuffer file. + model_content: Content of model. Raises: - ValueError: If the interpreter was unable to open the model. + ValueError: If the interpreter was unable to be created. 
""" - self._interpreter = ( - interpreter_wrapper.InterpreterWrapper_CreateWrapperCPP(model_path)) - if not self._interpreter: - raise ValueError('Failed to open {}'.format(model_path)) + if model_path and not model_content: + self._interpreter = ( + interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile( + model_path)) + if not self._interpreter: + raise ValueError('Failed to open {}'.format(model_path)) + elif model_content and not model_path: + self._interpreter = ( + interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( + model_content, len(model_content))) + if not self._interpreter: + raise ValueError( + 'Failed to create model from {} bytes'.format(len(model_content))) + elif not model_path and not model_content: + raise ValueError('`model_path` or `model_content` must be specified.') + else: + raise ValueError('Can\'t both provide `model_path` and `model_content`') def allocate_tensors(self): if not self._interpreter.AllocateTensors(): diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index e0215b721c..e85390c56c 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import io import numpy as np from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper @@ -29,7 +30,8 @@ class InterpreterTest(test_util.TensorFlowTestCase): def testFloat(self): interpreter = interpreter_wrapper.Interpreter( - resource_loader.get_path_to_datafile('testdata/permute_float.tflite')) + model_path=resource_loader.get_path_to_datafile( + 'testdata/permute_float.tflite')) interpreter.allocate_tensors() input_details = interpreter.get_input_details() @@ -53,29 +55,32 @@ class InterpreterTest(test_util.TensorFlowTestCase): self.assertTrue((expected_output == output_data).all()) def 
testUint8(self): - interpreter = interpreter_wrapper.Interpreter( - resource_loader.get_path_to_datafile('testdata/permute_uint8.tflite')) - interpreter.allocate_tensors() - - input_details = interpreter.get_input_details() - self.assertEqual(1, len(input_details)) - self.assertEqual('input', input_details[0]['name']) - self.assertEqual(np.uint8, input_details[0]['dtype']) - self.assertTrue(([1, 4] == input_details[0]['shape']).all()) - - output_details = interpreter.get_output_details() - self.assertEqual(1, len(output_details)) - self.assertEqual('output', output_details[0]['name']) - self.assertEqual(np.uint8, output_details[0]['dtype']) - self.assertTrue(([1, 4] == output_details[0]['shape']).all()) - - test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) - expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) - interpreter.set_tensor(input_details[0]['index'], test_input) - interpreter.invoke() - - output_data = interpreter.get_tensor(output_details[0]['index']) - self.assertTrue((expected_output == output_data).all()) + model_path = resource_loader.get_path_to_datafile( + 'testdata/permute_uint8.tflite') + with io.open(model_path, 'rb') as model_file: + data = model_file.read() + interpreter = interpreter_wrapper.Interpreter(model_content=data) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('input', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 4] == input_details[0]['shape']).all()) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 4] == output_details[0]['shape']).all()) + + test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) + expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) + 
interpreter.set_tensor(input_details[0]['index'], test_input) + interpreter.invoke() + + output_data = interpreter.get_tensor(output_details[0]['index']) + self.assertTrue((expected_output == output_data).all()) if __name__ == '__main__': diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index f30067de94..14e1190c80 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -302,12 +302,19 @@ PyObject* InterpreterWrapper::GetTensor(int i) const { return PyArray_Return(reinterpret_cast(np_array)); } -InterpreterWrapper* InterpreterWrapper::CreateWrapperCPP( +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( const char* model_path) { std::unique_ptr model = tflite::FlatBufferModel::BuildFromFile(model_path); return model ? new InterpreterWrapper(std::move(model)) : nullptr; } +InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( + const char* data, size_t len) { + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromBuffer(data, len); + return model ? new InterpreterWrapper(std::move(model)) : nullptr; +} + } // namespace interpreter_wrapper } // namespace tflite diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index dea71ca879..63bdb30f79 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -37,7 +37,11 @@ namespace interpreter_wrapper { class InterpreterWrapper { public: // SWIG caller takes ownership of pointer. 
- static InterpreterWrapper* CreateWrapperCPP(const char* model_path); + static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); + + // SWIG caller takes ownership of pointer. + static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data, + size_t len); ~InterpreterWrapper(); bool AllocateTensors(); -- cgit v1.2.3