aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java')
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java55
1 files changed, 49 insertions, 6 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
index e438956c7d..34a4361626 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -19,10 +19,16 @@ import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;
+import java.util.StringTokenizer;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import org.tensorflow.demo.env.Logger;
@@ -80,7 +86,7 @@ public class TensorFlowMultiBoxDetector implements Classifier {
final float imageStd,
final String inputName,
final String outputName) {
- TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
+ final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
d.inputName = inputName;
d.inputSize = inputSize;
d.imageMean = imageMean;
@@ -89,7 +95,11 @@ public class TensorFlowMultiBoxDetector implements Classifier {
d.boxPriors = new float[numLocations * 8];
- d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
+ try {
+ d.loadCoderOptions(assetManager, locationFilename, d.boxPriors);
+ } catch (final IOException e) {
+ throw new RuntimeException("Error initializing box priors from " + locationFilename);
+ }
// Pre-allocate buffers.
d.outputNames = outputName.split(",");
@@ -110,9 +120,42 @@ public class TensorFlowMultiBoxDetector implements Classifier {
private TensorFlowMultiBoxDetector() {}
- // Load BoxCoderOptions from native code.
- private native void loadCoderOptions(
- AssetManager assetManager, String locationFilename, float[] boxPriors);
+ private void loadCoderOptions(
+ final AssetManager assetManager, final String locationFilename, final float[] boxPriors)
+ throws IOException {
+ // Try to be intelligent about opening from assets or sdcard depending on prefix.
+ final String assetPrefix = "file:///android_asset/";
+ InputStream is;
+ if (locationFilename.startsWith(assetPrefix)) {
+ is = assetManager.open(locationFilename.split(assetPrefix)[1]);
+ } else {
+ is = new FileInputStream(locationFilename);
+ }
+
+ // Read values. Number of values per line doesn't matter, as long as they are separated
+ // by commas and/or whitespace, and there are exactly numLocations * 8 values total.
+ // Values are in the order mean, std for each consecutive corner of each box, for a total of 8
+ // per location.
+ final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
+ int priorIndex = 0;
+ String line;
+ while ((line = reader.readLine()) != null) {
+ final StringTokenizer st = new StringTokenizer(line, ", ");
+ while (st.hasMoreTokens()) {
+ final String token = st.nextToken();
+ try {
+ final float number = Float.parseFloat(token);
+ boxPriors[priorIndex++] = number;
+ } catch (final NumberFormatException e) {
+ // Silently ignore.
+ }
+ }
+ }
+ if (priorIndex != boxPriors.length) {
+ throw new RuntimeException(
+ "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length);
+ }
+ }
private float[] decodeLocationsEncoding(final float[] locationEncoding) {
final float[] locations = new float[locationEncoding.length];
@@ -216,7 +259,7 @@ public class TensorFlowMultiBoxDetector implements Classifier {
}
@Override
- public void enableStatLogging(boolean debug) {
+ public void enableStatLogging(final boolean debug) {
inferenceInterface.enableStatLogging(debug);
}