aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java29
1 files changed, 10 insertions, 19 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index fd1f0ffa68..7002f82677 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable {
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
* model inference is done.
- * @param output a multidimensional array of output data.
+ * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
+ * types including int, float, long, and byte.
*/
public void run(@NonNull Object input, @NonNull Object output) {
Object[] inputs = {input};
@@ -155,28 +156,16 @@ public final class Interpreter implements AutoCloseable {
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
* unchanged until model inference is done.
- * @param outputs a map mapping output indices to multidimensional arrays of output data. It only
- * needs to keep entries for the outputs to be used.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
+ * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep
+ * entries for the outputs to be used.
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
if (wrapper == null) {
throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
}
- Tensor[] tensors = wrapper.run(inputs);
- if (outputs == null || tensors == null || outputs.size() > tensors.length) {
- throw new IllegalArgumentException("Output error: Outputs do not match with model outputs.");
- }
- final int size = tensors.length;
- for (Integer idx : outputs.keySet()) {
- if (idx == null || idx < 0 || idx >= size) {
- throw new IllegalArgumentException(
- String.format(
- "Output error: Invalid index of output %d (should be in range [0, %d))",
- idx, size));
- }
- tensors[idx].copyTo(outputs.get(idx));
- }
+ wrapper.run(inputs, outputs);
}
/**
@@ -249,8 +238,10 @@ public final class Interpreter implements AutoCloseable {
/** Release resources associated with the {@code Interpreter}. */
@Override
public void close() {
- wrapper.close();
- wrapper = null;
+ if (wrapper != null) {
+ wrapper.close();
+ wrapper = null;
+ }
}
@Override