diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java')
-rw-r--r-- | tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java | 29 |
1 files changed, 10 insertions, 19 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index fd1f0ffa68..7002f82677 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -135,7 +135,8 @@ public final class Interpreter implements AutoCloseable { * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large * input data. When {@link ByteBuffer} is used, its content should remain unchanged until * model inference is done. - * @param output a multidimensional array of output data. + * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive + * types including int, float, long, and byte. */ public void run(@NonNull Object input, @NonNull Object output) { Object[] inputs = {input}; @@ -155,28 +156,16 @@ public final class Interpreter implements AutoCloseable { * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred * way to pass large input data. When {@link ByteBuffer} is used, its content should remain * unchanged until model inference is done. - * @param outputs a map mapping output indices to multidimensional arrays of output data. It only - * needs to keep entries for the outputs to be used. + * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link + * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep + * entries for the outputs to be used. */ public void runForMultipleInputsOutputs( @NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) { if (wrapper == null) { throw new IllegalStateException("Internal error: The Interpreter has already been closed."); } - Tensor[] tensors = wrapper.run(inputs); - if (outputs == null || tensors == null || outputs.size() > tensors.length) { - throw new IllegalArgumentException("Output error: Outputs do not match with model outputs."); - } - final int size = tensors.length; - for (Integer idx : outputs.keySet()) { - if (idx == null || idx < 0 || idx >= size) { - throw new IllegalArgumentException( - String.format( - "Output error: Invalid index of output %d (should be in range [0, %d))", - idx, size)); - } - tensors[idx].copyTo(outputs.get(idx)); - } + wrapper.run(inputs, outputs); } /** @@ -249,8 +238,10 @@ public final class Interpreter implements AutoCloseable { /** Release resources associated with the {@code Interpreter}. */ @Override public void close() { - wrapper.close(); - wrapper = null; + if (wrapper != null) { + wrapper.close(); + wrapper = null; + } } @Override |