aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/android
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java23
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 395dd6c5d2..80e03f2036 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -31,12 +31,13 @@ import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.List;
-import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
+import org.tensorflow.Tensors;
+import org.tensorflow.types.UInt8;
/**
* Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface
@@ -328,7 +329,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, byte[] src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src)));
}
/**
@@ -337,7 +338,7 @@ public class TensorFlowInferenceInterface {
* a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
/**
@@ -346,7 +347,7 @@ public class TensorFlowInferenceInterface {
* arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters).
*/
public void feedString(String inputName, byte[][] src) {
- addFeed(inputName, Tensor.create(src));
+ addFeed(inputName, Tensors.create(src));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
@@ -403,7 +404,7 @@ public class TensorFlowInferenceInterface {
* destination has capacity, the copy is truncated.
*/
public void feed(String inputName, ByteBuffer src, long... dims) {
- addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
+ addFeed(inputName, Tensor.create(UInt8.class, dims, src));
}
/**
@@ -544,7 +545,7 @@ public class TensorFlowInferenceInterface {
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
- private void addFeed(String inputName, Tensor t) {
+ private void addFeed(String inputName, Tensor<?> t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
runner.feed(tid.name, tid.outputIndex, t);
@@ -578,7 +579,7 @@ public class TensorFlowInferenceInterface {
}
}
- private Tensor getTensor(String outputName) {
+ private Tensor<?> getTensor(String outputName) {
int i = 0;
for (String n : fetchNames) {
if (n.equals(outputName)) {
@@ -591,7 +592,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFeeds() {
- for (Tensor t : feedTensors) {
+ for (Tensor<?> t : feedTensors) {
t.close();
}
feedTensors.clear();
@@ -599,7 +600,7 @@ public class TensorFlowInferenceInterface {
}
private void closeFetches() {
- for (Tensor t : fetchTensors) {
+ for (Tensor<?> t : fetchTensors) {
t.close();
}
fetchTensors.clear();
@@ -614,9 +615,9 @@ public class TensorFlowInferenceInterface {
// State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
- private List<Tensor> feedTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>();
private List<String> fetchNames = new ArrayList<String>();
- private List<Tensor> fetchTensors = new ArrayList<Tensor>();
+ private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>();
// Mutable state.
private RunStats runStats;