diff options
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java')
-rw-r--r-- | tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java | 55 |
1 files changed, 49 insertions, 6 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index e438956c7d..34a4361626 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -19,10 +19,16 @@ import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.RectF; import android.os.Trace; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.PriorityQueue; +import java.util.StringTokenizer; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; import org.tensorflow.demo.env.Logger; @@ -80,7 +86,7 @@ public class TensorFlowMultiBoxDetector implements Classifier { final float imageStd, final String inputName, final String outputName) { - TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector(); + final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector(); d.inputName = inputName; d.inputSize = inputSize; d.imageMean = imageMean; @@ -89,7 +95,11 @@ public class TensorFlowMultiBoxDetector implements Classifier { d.boxPriors = new float[numLocations * 8]; - d.loadCoderOptions(assetManager, locationFilename, d.boxPriors); + try { + d.loadCoderOptions(assetManager, locationFilename, d.boxPriors); + } catch (final IOException e) { + throw new RuntimeException("Error initializing box priors from " + locationFilename); + } // Pre-allocate buffers. d.outputNames = outputName.split(","); @@ -110,9 +120,42 @@ public class TensorFlowMultiBoxDetector implements Classifier { private TensorFlowMultiBoxDetector() {} - // Load BoxCoderOptions from native code. - private native void loadCoderOptions( - AssetManager assetManager, String locationFilename, float[] boxPriors); + private void loadCoderOptions( + final AssetManager assetManager, final String locationFilename, final float[] boxPriors) + throws IOException { + // Try to be intelligent about opening from assets or sdcard depending on prefix. + final String assetPrefix = "file:///android_asset/"; + InputStream is; + if (locationFilename.startsWith(assetPrefix)) { + is = assetManager.open(locationFilename.split(assetPrefix)[1]); + } else { + is = new FileInputStream(locationFilename); + } + + // Read values. Number of values per line doesn't matter, as long as they are separated + // by commas and/or whitespace, and there are exactly numLocations * 8 values total. + // Values are in the order mean, std for each consecutive corner of each box, for a total of 8 + // per location. + final BufferedReader reader = new BufferedReader(new InputStreamReader(is)); + int priorIndex = 0; + String line; + while ((line = reader.readLine()) != null) { + final StringTokenizer st = new StringTokenizer(line, ", "); + while (st.hasMoreTokens()) { + final String token = st.nextToken(); + try { + final float number = Float.parseFloat(token); + boxPriors[priorIndex++] = number; + } catch (final NumberFormatException e) { + // Silently ignore. + } + } + } + if (priorIndex != boxPriors.length) { + throw new RuntimeException( + "BoxPrior length mismatch: " + priorIndex + " vs " + boxPriors.length); + } + } private float[] decodeLocationsEncoding(final float[] locationEncoding) { final float[] locations = new float[locationEncoding.length]; @@ -216,7 +259,7 @@ public class TensorFlowMultiBoxDetector implements Classifier { } @Override - public void enableStatLogging(boolean debug) { + public void enableStatLogging(final boolean debug) { inferenceInterface.enableStatLogging(debug); } |