aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-20 13:38:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-20 13:41:44 -0700
commit4a6ab2cb8c2f33ffb6b64d61bd09f006e75982c8 (patch)
treef228b45b30335c6993fe9703fa340f05f37c879b
parent13ae129449cdeb7afbad98bc8a00ad5c82a0ca31 (diff)
Build tflite interpreter from buffer in python interface
PiperOrigin-RevId: 189800400
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py26
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py53
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc9
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h6
4 files changed, 62 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 5b5a7c3199..accdd04671 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -23,19 +23,33 @@ from tensorflow.contrib.lite.python.interpreter_wrapper import tensorflow_wrap_i
class Interpreter(object):
"""Interpreter inferace for TF-Lite Models."""
- def __init__(self, model_path):
+ def __init__(self, model_path=None, model_content=None):
"""Constructor.
Args:
model_path: Path to TF-Lite Flatbuffer file.
+ model_content: Content of model.
Raises:
- ValueError: If the interpreter was unable to open the model.
+ ValueError: If the interpreter was unable to create.
"""
- self._interpreter = (
- interpreter_wrapper.InterpreterWrapper_CreateWrapperCPP(model_path))
- if not self._interpreter:
- raise ValueError('Failed to open {}'.format(model_path))
+ if model_path and not model_content:
+ self._interpreter = (
+ interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromFile(
+ model_path))
+ if not self._interpreter:
+ raise ValueError('Failed to open {}'.format(model_path))
+ elif model_content and not model_path:
+ self._interpreter = (
+ interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
+ model_content, len(model_content)))
+ if not self._interpreter:
+ raise ValueError(
+ 'Failed to create model from {} bytes'.format(len(model_content)))
+ elif not model_path and not model_path:
+ raise ValueError('`model_path` or `model_content` must be specified.')
+ else:
+ raise ValueError('Can\'t both provide `model_path` and `model_content`')
def allocate_tensors(self):
if not self._interpreter.AllocateTensors():
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index e0215b721c..e85390c56c 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import io
import numpy as np
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
@@ -29,7 +30,8 @@ class InterpreterTest(test_util.TensorFlowTestCase):
def testFloat(self):
interpreter = interpreter_wrapper.Interpreter(
- resource_loader.get_path_to_datafile('testdata/permute_float.tflite'))
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
@@ -53,29 +55,32 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertTrue((expected_output == output_data).all())
def testUint8(self):
- interpreter = interpreter_wrapper.Interpreter(
- resource_loader.get_path_to_datafile('testdata/permute_uint8.tflite'))
- interpreter.allocate_tensors()
-
- input_details = interpreter.get_input_details()
- self.assertEqual(1, len(input_details))
- self.assertEqual('input', input_details[0]['name'])
- self.assertEqual(np.uint8, input_details[0]['dtype'])
- self.assertTrue(([1, 4] == input_details[0]['shape']).all())
-
- output_details = interpreter.get_output_details()
- self.assertEqual(1, len(output_details))
- self.assertEqual('output', output_details[0]['name'])
- self.assertEqual(np.uint8, output_details[0]['dtype'])
- self.assertTrue(([1, 4] == output_details[0]['shape']).all())
-
- test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
- expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
- interpreter.set_tensor(input_details[0]['index'], test_input)
- interpreter.invoke()
-
- output_data = interpreter.get_tensor(output_details[0]['index'])
- self.assertTrue((expected_output == output_data).all())
+ model_path = resource_loader.get_path_to_datafile(
+ 'testdata/permute_uint8.tflite')
+ with io.open(model_path, 'rb') as model_file:
+ data = model_file.read()
+ interpreter = interpreter_wrapper.Interpreter(model_content=data)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.uint8, input_details[0]['dtype'])
+ self.assertTrue(([1, 4] == input_details[0]['shape']).all())
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+
+ test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
+ expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
+ interpreter.set_tensor(input_details[0]['index'], test_input)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue((expected_output == output_data).all())
if __name__ == '__main__':
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index f30067de94..14e1190c80 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -302,12 +302,19 @@ PyObject* InterpreterWrapper::GetTensor(int i) const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
-InterpreterWrapper* InterpreterWrapper::CreateWrapperCPP(
+InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
const char* model_path) {
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(model_path);
return model ? new InterpreterWrapper(std::move(model)) : nullptr;
}
+InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
+ const char* data, size_t len) {
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromBuffer(data, len);
+ return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+}
+
} // namespace interpreter_wrapper
} // namespace tflite
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index dea71ca879..63bdb30f79 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -37,7 +37,11 @@ namespace interpreter_wrapper {
class InterpreterWrapper {
public:
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPP(const char* model_path);
+ static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path);
+
+ // SWIG caller takes ownership of pointer.
+ static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data,
+ size_t len);
~InterpreterWrapper();
bool AllocateTensors();