diff options
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r-- | tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 395dd6c5d2..80e03f2036 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -31,12 +31,13 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import java.util.ArrayList; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; +import org.tensorflow.Tensors; +import org.tensorflow.types.UInt8; /** * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface @@ -328,7 +329,7 @@ public class TensorFlowInferenceInterface { * destination has capacity, the copy is truncated. */ public void feed(String inputName, byte[] src, long... dims) { - addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src))); + addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src))); } /** @@ -337,7 +338,7 @@ public class TensorFlowInferenceInterface { * a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } /** @@ -346,7 +347,7 @@ public class TensorFlowInferenceInterface { * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[][] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } // Methods for taking a native Tensor and filling it with src from Java native IO buffers. @@ -403,7 +404,7 @@ public class TensorFlowInferenceInterface { * destination has capacity, the copy is truncated. */ public void feed(String inputName, ByteBuffer src, long... dims) { - addFeed(inputName, Tensor.create(DataType.UINT8, dims, src)); + addFeed(inputName, Tensor.create(UInt8.class, dims, src)); } /** @@ -544,7 +545,7 @@ public class TensorFlowInferenceInterface { "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); } - private void addFeed(String inputName, Tensor t) { + private void addFeed(String inputName, Tensor<?> t) { // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. TensorId tid = TensorId.parse(inputName); runner.feed(tid.name, tid.outputIndex, t); @@ -578,7 +579,7 @@ public class TensorFlowInferenceInterface { } } - private Tensor getTensor(String outputName) { + private Tensor<?> getTensor(String outputName) { int i = 0; for (String n : fetchNames) { if (n.equals(outputName)) { @@ -591,7 +592,7 @@ public class TensorFlowInferenceInterface { } private void closeFeeds() { - for (Tensor t : feedTensors) { + for (Tensor<?> t : feedTensors) { t.close(); } feedTensors.clear(); @@ -599,7 +600,7 @@ public class TensorFlowInferenceInterface { } private void closeFetches() { - for (Tensor t : fetchTensors) { + for (Tensor<?> t : fetchTensors) { t.close(); } fetchTensors.clear(); @@ -614,9 +615,9 @@ public class TensorFlowInferenceInterface { // State reset on every call to run. private Session.Runner runner; private List<String> feedNames = new ArrayList<String>(); - private List<Tensor> feedTensors = new ArrayList<Tensor>(); + private List<Tensor<?>> feedTensors = new ArrayList<Tensor<?>>(); private List<String> fetchNames = new ArrayList<String>(); - private List<Tensor> fetchTensors = new ArrayList<Tensor>(); + private List<Tensor<?>> fetchTensors = new ArrayList<Tensor<?>>(); // Mutable state. private RunStats runStats; |