aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-07-20 16:27:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 16:36:23 -0700
commita4b95884f870a040038e530c978239999933acd9 (patch)
tree0e2e27120184834c4e57c7770695211b2242be49 /tensorflow/contrib/lite/python
parent9e61678787d329322dd729db92e833c874bdf835 (diff)
TFLite Python: Make resize_input_tensor accept list/tuple sizes.
PiperOrigin-RevId: 205471451
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/BUILD1
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py6
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py2
3 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 727fbff38e..860aff9e7e 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -20,6 +20,7 @@ py_library(
deps = [
"//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
"//tensorflow/python:util",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index e1981ceae2..3243bddac8 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import sys
+import numpy as np
from tensorflow.python.util.lazy_loader import LazyLoader
# Lazy load since some of the performance benchmark skylark rules
@@ -162,6 +163,9 @@ class Interpreter(object):
ValueError: If the interpreter could not resize the input tensor.
"""
self._ensure_safe()
+ # `ResizeInputTensor` now only accepts int32 numpy array as `tensor_size
+ # parameter.
+ tensor_size = np.array(tensor_size, dtype=np.int32)
self._interpreter.ResizeInputTensor(input_index, tensor_size)
def get_output_details(self):
@@ -204,7 +208,7 @@ class Interpreter(object):
for i in range(10):
input().fill(3.)
interpreter.invoke()
- print("inference %s" % output)
+ print("inference %s" % output())
Notice how this function avoids making a numpy array directly. This is
because it is important to not hold actual numpy views to the data longer
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index 95fa4b8584..e77d52ca99 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -83,7 +83,7 @@ class InterpreterTest(test_util.TensorFlowTestCase):
test_input = np.array([[1, 2, 3, 4]], dtype=np.uint8)
expected_output = np.array([[4, 3, 2, 1]], dtype=np.uint8)
interpreter.resize_tensor_input(input_details[0]['index'],
- np.array(test_input.shape, dtype=np.int32))
+ test_input.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], test_input)
interpreter.invoke()