diff options
Diffstat (limited to 'tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java')
-rw-r--r-- | tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index b1d18d2faf..587f2941e5 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -27,6 +27,7 @@ import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; +import java.nio.LongBuffer; import java.util.ArrayList; import java.util.List; import org.tensorflow.DataType; @@ -232,6 +233,16 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ + public void feed(String inputName, long[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, LongBuffer.wrap(src))); + } + + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ public void feed(String inputName, double[] src, long... dims) { addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); } @@ -277,6 +288,17 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ + public void feed(String inputName, LongBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); + } + + /** + * Given a source buffer with shape {@link dims} and content {@link src}, both stored as + * <b>direct</b> and <b>native ordered</b> java.nio buffers, copy the contents into the input + * Tensor with name {@link inputName}. The source buffer {@link src} must have at least as many + * elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ public void feed(String inputName, DoubleBuffer src, long... dims) { addFeed(inputName, Tensor.create(dims, src)); } @@ -315,6 +337,15 @@ public class TensorFlowInferenceInterface { * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. */ + public void fetch(String outputName, long[] dst) { + fetch(outputName, LongBuffer.wrap(dst)); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link + * dst} must have length greater than or equal to that of the source Tensor. This operation will + * not affect dst's content past the source Tensor's size. + */ public void fetch(String outputName, double[] dst) { fetch(outputName, DoubleBuffer.wrap(dst)); } @@ -354,6 +385,16 @@ public class TensorFlowInferenceInterface { * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. */ + public void fetch(String outputName, LongBuffer dst) { + getTensor(outputName).writeTo(dst); + } + + /** + * Read from a Tensor named {@link outputName} and copy the contents into the <b>direct</b> and + * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than + * or equal to that of the source Tensor. This operation will not affect dst's content past the + * source Tensor's size. + */ public void fetch(String outputName, DoubleBuffer dst) { getTensor(outputName).writeTo(dst); } |