aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/android
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 80e03f2036..1f423a7a5b 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -288,6 +288,22 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
+ public void feed(String inputName, boolean[] src, long... dims) {
+ byte[] b = new byte[src.length];
+
+ for (int i = 0; i < src.length; i++) {
+ b[i] = src[i] ? (byte) 1 : (byte) 0;
+ }
+
+ addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
+ }
+
+ /**
+ * Given a source array with shape {@link dims} and content {@link src}, copy the contents into
+ * the input Tensor with name {@link inputName}. The source array {@link src} must have at least
+ * as many elements as that of the destination Tensor. If {@link src} has more elements than the
+ * destination has capacity, the copy is truncated.
+ */
public void feed(String inputName, float[] src, long... dims) {
addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
}