diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-07-20 16:27:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 16:36:23 -0700 |
commit | a4b95884f870a040038e530c978239999933acd9 (patch) | |
tree | 0e2e27120184834c4e57c7770695211b2242be49 /tensorflow/contrib/lite/python | |
parent | 9e61678787d329322dd729db92e833c874bdf835 (diff) |
TFLite Python: Make resize_input_tensor accept list/tuple sizes.
PiperOrigin-RevId: 205471451
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/interpreter_test.py | 2 |
3 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 727fbff38e..860aff9e7e 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -20,6 +20,7 @@ py_library( deps = [ "//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper", "//tensorflow/python:util", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index e1981ceae2..3243bddac8 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import sys +import numpy as np from tensorflow.python.util.lazy_loader import LazyLoader # Lazy load since some of the performance benchmark skylark rules @@ -162,6 +163,9 @@ class Interpreter(object): ValueError: If the interpreter could not resize the input tensor. """ self._ensure_safe() + # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size + # parameter. + tensor_size = np.array(tensor_size, dtype=np.int32) self._interpreter.ResizeInputTensor(input_index, tensor_size) def get_output_details(self): @@ -204,7 +208,7 @@ class Interpreter(object): for i in range(10): input().fill(3.) interpreter.invoke() - print("inference %s" % output) + print("inference %s" % output()) Notice how this function avoids making a numpy array directly. This is because it is important to not hold actual numpy views to the data longer diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py index 95fa4b8584..e77d52ca99 100644 --- a/tensorflow/contrib/lite/python/interpreter_test.py +++ b/tensorflow/contrib/lite/python/interpreter_test.py @@ -83,7 +83,7 @@ class InterpreterTest(test_util.TensorFlowTestCase): test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8) expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8) interpreter.resize_tensor_input(input_details[0]['index'], - np.array(test_input.shape, dtype=np.int32)) + test_input.shape) interpreter.allocate_tensors() interpreter.set_tensor(input_details[0]['index'], test_input) interpreter.invoke() |