 WORKSPACE                                                                                  |  10
 tensorflow/core/framework/register_types.h                                                 |   4
 tensorflow/examples/android/BUILD                                                          |   2
 tensorflow/examples/android/README.md                                                      |  12
 tensorflow/examples/android/download-models.gradle                                         |   1
 tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java                  |  60
 tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java        |   2
 tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java | 219
 tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java          |   5
 9 files changed, 292 insertions(+), 23 deletions(-)
diff --git a/WORKSPACE b/WORKSPACE
index 6b5d24560c..959587387e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -52,6 +52,16 @@ new_http_archive(
)
new_http_archive(
+ name = "mobile_ssd",
+ build_file = "models.BUILD",
+ sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8",
+ urls = [
+ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
+ "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip",
+ ],
+)
+
+new_http_archive(
name = "mobile_multibox",
build_file = "models.BUILD",
sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96",
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index 973ad4544a..030c00cb8e 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -95,7 +95,7 @@ limitations under the License.
#define TF_CALL_resource(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m) m(::tensorflow::int64)
-#define TF_CALL_bool(m)
+#define TF_CALL_bool(m) m(bool)
#define TF_CALL_qint8(m) m(::tensorflow::qint8)
#define TF_CALL_quint8(m) m(::tensorflow::quint8)
@@ -122,7 +122,7 @@ limitations under the License.
#define TF_CALL_resource(m)
#define TF_CALL_complex64(m)
#define TF_CALL_int64(m)
-#define TF_CALL_bool(m)
+#define TF_CALL_bool(m) m(bool)
#define TF_CALL_qint8(m)
#define TF_CALL_quint8(m)
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 71c16e2399..2d3b0911fc 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -92,7 +92,7 @@ filegroup(
name = "external_assets",
srcs = [
"@inception5h//:model_files",
- "@mobile_multibox//:model_files",
+ "@mobile_ssd//:model_files",
"@stylize//:model_files",
],
)
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index 51f6c4a71c..f9881287cd 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -24,12 +24,14 @@ on API >= 14 devices.
model to classify camera frames in real-time, displaying the top results
in an overlay on the camera image.
2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java):
- Demonstrates a model based on [Scalable Object Detection
- using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to
- localize and track people in the camera preview in real-time.
+ Demonstrates an SSD-Mobilenet model trained using the
+ [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/object_detection/)
+ introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to
+ localize and track objects (from 80 categories) in the camera preview
+ in real-time.
3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
Uses a model based on [A Learned Representation For Artistic
- Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
+ Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview
image to that of a number of different artists.
<img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%">
@@ -149,7 +151,7 @@ and extract the archives yourself to the `assets` directory in the source tree:
```bash
BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models
-for MODEL_ZIP in inception5h.zip mobile_multibox_v1a.zip stylize_v1.zip
+for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip
do
curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP}
unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/
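For anyone scripting this outside the shell, here is a minimal Java sketch of the same fetch-and-extract loop (standard library only; a sketch, not part of the patch or the Gradle pipeline):

```java
import java.io.IOException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

public final class FetchModels {
  public static void main(String[] args) throws IOException {
    final String baseUrl = "https://storage.googleapis.com/download.tensorflow.org/models";
    final String[] zips = {
      "inception5h.zip", "ssd_mobilenet_v1_android_export.zip", "stylize_v1.zip"
    };
    final Path assets = Paths.get("tensorflow/examples/android/assets");
    Files.createDirectories(assets);
    for (String zip : zips) {
      // Stream each archive and unpack its entries directly into assets/.
      try (ZipInputStream in = new ZipInputStream(new URL(baseUrl + "/" + zip).openStream())) {
        ZipEntry entry;
        while ((entry = in.getNextEntry()) != null) {
          final Path out = assets.resolve(entry.getName()).normalize();
          if (entry.isDirectory()) {
            Files.createDirectories(out);
          } else {
            Files.createDirectories(out.getParent());
            Files.copy(in, out, StandardCopyOption.REPLACE_EXISTING);
          }
        }
      }
    }
  }
}
```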
diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle
index aca015fa26..f1750ffddb 100644
--- a/tensorflow/examples/android/download-models.gradle
+++ b/tensorflow/examples/android/download-models.gradle
@@ -10,6 +10,7 @@
// hard coded model files
// LINT.IfChange
def models = ['inception5h.zip',
+ 'ssd_mobilenet_v1_android_export.zip',
'mobile_multibox_v1a.zip',
'stylize_v1.zip']
// LINT.ThenChange(//tensorflow/examples/android/BUILD)
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
index acace0eace..b2a22f3963 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java
@@ -34,6 +34,8 @@ import android.os.Trace;
import android.util.Size;
import android.util.TypedValue;
import android.view.Display;
+import android.widget.Toast;
+import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Vector;
@@ -62,6 +64,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
private static final String MB_LOCATION_FILE =
"file:///android_asset/multibox_location_priors.txt";
+ private static final int TF_OD_API_INPUT_SIZE = 300;
+ private static final String TF_OD_API_MODEL_FILE =
+ "file:///android_asset/ssd_mobilenet_v1_android_export.pb";
+ private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
+
// Configuration values for tiny-yolo-voc. Note that the graph is not included with TensorFlow and
// must be manually placed in the assets/ directory by the user.
// Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via
@@ -73,15 +80,20 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
private static final String YOLO_OUTPUT_NAMES = "output";
private static final int YOLO_BLOCK_SIZE = 32;
- // Default to the included multibox model.
- private static final boolean USE_YOLO = false;
-
- private static final int CROP_SIZE = USE_YOLO ? YOLO_INPUT_SIZE : MB_INPUT_SIZE;
+ // Which detection model to use: by default uses the TensorFlow Object Detection API frozen
+ // checkpoints. Optionally use legacy Multibox (trained using an older version of the API)
+ // or YOLO.
+ private enum DetectorMode {
+ TF_OD_API, MULTIBOX, YOLO;
+ }
+ private static final DetectorMode MODE = DetectorMode.TF_OD_API;
// Minimum detection confidence to track a detection.
- private static final float MINIMUM_CONFIDENCE = USE_YOLO ? 0.25f : 0.1f;
+ private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f;
+ private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f;
+ private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f;
- private static final boolean MAINTAIN_ASPECT = USE_YOLO;
+ private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO;
private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
@@ -126,8 +138,8 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
tracker = new MultiBoxTracker(this);
-
- if (USE_YOLO) {
+ int cropSize = TF_OD_API_INPUT_SIZE;
+ if (MODE == DetectorMode.YOLO) {
detector =
TensorFlowYoloDetector.create(
getAssets(),
@@ -136,7 +148,8 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
YOLO_INPUT_NAME,
YOLO_OUTPUT_NAMES,
YOLO_BLOCK_SIZE);
- } else {
+ cropSize = YOLO_INPUT_SIZE;
+ } else if (MODE == DetectorMode.MULTIBOX) {
detector =
TensorFlowMultiBoxDetector.create(
getAssets(),
@@ -147,6 +160,20 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
MB_INPUT_NAME,
MB_OUTPUT_LOCATIONS_NAME,
MB_OUTPUT_SCORES_NAME);
+ cropSize = MB_INPUT_SIZE;
+ } else {
+ try {
+ detector = TensorFlowObjectDetectionAPIModel.create(
+ getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
+ cropSize = TF_OD_API_INPUT_SIZE;
+ } catch (final IOException e) {
+ LOGGER.e("Exception initializing classifier!", e);
+ Toast toast =
+ Toast.makeText(
+ getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
+ toast.show();
+ finish();
+ }
}
previewWidth = size.getWidth();
@@ -162,12 +189,12 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
rgbBytes = new int[previewWidth * previewHeight];
rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
- croppedBitmap = Bitmap.createBitmap(CROP_SIZE, CROP_SIZE, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);
frameToCropTransform =
ImageUtils.getTransformationMatrix(
previewWidth, previewHeight,
- CROP_SIZE, CROP_SIZE,
+ cropSize, cropSize,
sensorOrientation, MAINTAIN_ASPECT);
cropToFrameTransform = new Matrix();
@@ -322,12 +349,19 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
paint.setStyle(Style.STROKE);
paint.setStrokeWidth(2.0f);
+ float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
+ switch (MODE) {
+ case TF_OD_API: minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; break;
+ case MULTIBOX: minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX; break;
+ case YOLO: minimumConfidence = MINIMUM_CONFIDENCE_YOLO; break;
+ }
+
final List<Classifier.Recognition> mappedRecognitions =
new LinkedList<Classifier.Recognition>();
for (final Classifier.Recognition result : results) {
final RectF location = result.getLocation();
- if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) {
+ if (location != null && result.getConfidence() >= minimumConfidence) {
canvas.drawRect(location, paint);
cropToFrameTransform.mapRect(location);
@@ -347,7 +381,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
Trace.endSection();
}
- protected void processImageRGBbytes(int[] rgbBytes ) {}
+ protected void processImageRGBbytes(int[] rgbBytes) {}
@Override
protected int getLayoutId() {
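The DetectorMode switch above maps each mode to its confidence threshold. The same lookup can be kept as a table; here is a hedged standalone sketch using EnumMap, with the thresholds copied from the constants introduced in this patch:

```java
import java.util.EnumMap;
import java.util.Map;

public final class ThresholdTable {
  enum DetectorMode { TF_OD_API, MULTIBOX, YOLO }

  private static final Map<DetectorMode, Float> MIN_CONFIDENCE =
      new EnumMap<DetectorMode, Float>(DetectorMode.class);

  static {
    // Values from DetectorActivity: MINIMUM_CONFIDENCE_{TF_OD_API,MULTIBOX,YOLO}.
    MIN_CONFIDENCE.put(DetectorMode.TF_OD_API, 0.6f);
    MIN_CONFIDENCE.put(DetectorMode.MULTIBOX, 0.1f);
    MIN_CONFIDENCE.put(DetectorMode.YOLO, 0.25f);
  }

  static float minimumConfidence(DetectorMode mode) {
    return MIN_CONFIDENCE.get(mode);
  }

  public static void main(String[] args) {
    System.out.println(minimumConfidence(DetectorMode.TF_OD_API)); // prints 0.6
  }
}
```

Whether a switch or a map reads better is a style call; the map keeps mode-specific constants in one place as modes are added.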
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
index b4a231ff17..ea837e62d5 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -41,7 +41,7 @@ import org.tensorflow.demo.env.Logger;
public class TensorFlowMultiBoxDetector implements Classifier {
private static final Logger LOGGER = new Logger();
- // Only return this many results with at least this confidence.
+ // Only return this many results.
private static final int MAX_RESULTS = Integer.MAX_VALUE;
// Config values.
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java
new file mode 100644
index 0000000000..687318c7ce
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java
@@ -0,0 +1,219 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.demo;
+
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.RectF;
+import android.os.Trace;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.Vector;
+import org.tensorflow.Graph;
+import org.tensorflow.Operation;
+import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
+import org.tensorflow.demo.env.Logger;
+
+/**
+ * Wrapper for frozen detection models trained using the TensorFlow Object Detection API:
+ * github.com/tensorflow/models/tree/master/object_detection
+ */
+public class TensorFlowObjectDetectionAPIModel implements Classifier {
+ private static final Logger LOGGER = new Logger();
+
+ // Only return this many results.
+ private static final int MAX_RESULTS = 100;
+
+ // Config values.
+ private String inputName;
+ private int inputSize;
+
+ // Pre-allocated buffers.
+ private Vector<String> labels = new Vector<String>();
+ private int[] intValues;
+ private byte[] byteValues;
+ private float[] outputLocations;
+ private float[] outputScores;
+ private float[] outputClasses;
+ private float[] outputNumDetections;
+ private String[] outputNames;
+
+ private boolean logStats = false;
+
+ private TensorFlowInferenceInterface inferenceInterface;
+
+ /**
+ * Initializes a native TensorFlow session for classifying images.
+ *
+ * @param assetManager The asset manager to be used to load assets.
+ * @param modelFilename The filepath of the model GraphDef protocol buffer.
+ * @param labelFilename The filepath of label file for classes.
+ */
+ public static Classifier create(
+ final AssetManager assetManager,
+ final String modelFilename,
+ final String labelFilename,
+ final int inputSize) throws IOException {
+ final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel();
+
+ final String actualFilename = labelFilename.split("file:///android_asset/")[1];
+ final InputStream labelsInput = assetManager.open(actualFilename);
+ final BufferedReader br = new BufferedReader(new InputStreamReader(labelsInput));
+ String line;
+ while ((line = br.readLine()) != null) {
+ LOGGER.w(line);
+ d.labels.add(line);
+ }
+ br.close();
+
+
+ d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
+
+ final Graph g = d.inferenceInterface.graph();
+
+ d.inputName = "image_tensor";
+ // The inputName node has a shape of [N, H, W, C], where
+ // N is the batch size
+ // H = W are the height and width
+ // C is the number of channels (3 for our purposes - RGB)
+ final Operation inputOp = g.operation(d.inputName);
+ if (inputOp == null) {
+ throw new RuntimeException("Failed to find input Node '" + d.inputName + "'");
+ }
+ d.inputSize = inputSize;
+ // The outputScoresName node has a shape of [N, NumLocations], where N
+ // is the batch size.
+ final Operation outputOp1 = g.operation("detection_scores");
+ if (outputOp1 == null) {
+ throw new RuntimeException("Failed to find output Node 'detection_scores'");
+ }
+ final Operation outputOp2 = g.operation("detection_boxes");
+ if (outputOp2 == null) {
+ throw new RuntimeException("Failed to find output Node 'detection_boxes'");
+ }
+ final Operation outputOp3 = g.operation("detection_classes");
+ if (outputOp3 == null) {
+ throw new RuntimeException("Failed to find output Node 'detection_classes'");
+ }
+
+ // Pre-allocate buffers.
+ d.outputNames = new String[] {"detection_boxes", "detection_scores",
+ "detection_classes", "num_detections"};
+ d.intValues = new int[d.inputSize * d.inputSize];
+ d.byteValues = new byte[d.inputSize * d.inputSize * 3];
+ d.outputScores = new float[MAX_RESULTS];
+ d.outputLocations = new float[MAX_RESULTS * 4];
+ d.outputClasses = new float[MAX_RESULTS];
+ d.outputNumDetections = new float[1];
+ return d;
+ }
+
+ private TensorFlowObjectDetectionAPIModel() {}
+
+ @Override
+ public List<Recognition> recognizeImage(final Bitmap bitmap) {
+ // Log this method so that it can be analyzed with systrace.
+ Trace.beginSection("recognizeImage");
+
+ Trace.beginSection("preprocessBitmap");
+ // Preprocess the image data: unpack each packed ARGB pixel into three
+ // RGB bytes, as expected by the model's uint8 image_tensor input.
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+
+ for (int i = 0; i < intValues.length; ++i) {
+ byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF);
+ byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF);
+ byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF);
+ }
+ Trace.endSection(); // preprocessBitmap
+
+ // Copy the input data into TensorFlow.
+ Trace.beginSection("feed");
+ inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3);
+ Trace.endSection();
+
+ // Run the inference call.
+ Trace.beginSection("run");
+ inferenceInterface.run(outputNames, logStats);
+ Trace.endSection();
+
+ // Copy the output Tensor back into the output array.
+ Trace.beginSection("fetch");
+ outputLocations = new float[MAX_RESULTS * 4];
+ outputScores = new float[MAX_RESULTS];
+ outputClasses = new float[MAX_RESULTS];
+ outputNumDetections = new float[1];
+ inferenceInterface.fetch(outputNames[0], outputLocations);
+ inferenceInterface.fetch(outputNames[1], outputScores);
+ inferenceInterface.fetch(outputNames[2], outputClasses);
+ inferenceInterface.fetch(outputNames[3], outputNumDetections);
+ Trace.endSection();
+
+ // Find the best detections.
+ final PriorityQueue<Recognition> pq =
+ new PriorityQueue<Recognition>(
+ 1,
+ new Comparator<Recognition>() {
+ @Override
+ public int compare(final Recognition lhs, final Recognition rhs) {
+ // Intentionally reversed to put high confidence at the head of the queue.
+ return Float.compare(rhs.getConfidence(), lhs.getConfidence());
+ }
+ });
+
+ // Scale them back to the input size.
+ for (int i = 0; i < outputScores.length; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[4 * i + 1] * inputSize,
+ outputLocations[4 * i] * inputSize,
+ outputLocations[4 * i + 3] * inputSize,
+ outputLocations[4 * i + 2] * inputSize);
+ pq.add(
+ new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection));
+ }
+
+ final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
+ for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) {
+ recognitions.add(pq.poll());
+ }
+ Trace.endSection(); // "recognizeImage"
+ return recognitions;
+ }
+
+ @Override
+ public void enableStatLogging(final boolean logStats) {
+ this.logStats = logStats;
+ }
+
+ @Override
+ public String getStatString() {
+ return inferenceInterface.getStatString();
+ }
+
+ @Override
+ public void close() {
+ inferenceInterface.close();
+ }
+}
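A hedged usage sketch for the new wrapper, modeled on the DetectorActivity changes above. It assumes the sketch class sits in org.tensorflow.demo next to Classifier; the frame source and the drawing step are placeholders, and error handling is abbreviated:

```java
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import java.io.IOException;
import java.util.List;

public final class DetectorSketch {
  private static final int TF_OD_API_INPUT_SIZE = 300;
  private static final String MODEL_FILE =
      "file:///android_asset/ssd_mobilenet_v1_android_export.pb";
  private static final String LABELS_FILE =
      "file:///android_asset/coco_labels_list.txt";

  static void detectOnce(AssetManager assets, Bitmap frame) {
    try {
      final Classifier detector = TensorFlowObjectDetectionAPIModel.create(
          assets, MODEL_FILE, LABELS_FILE, TF_OD_API_INPUT_SIZE);
      // The model expects a 300x300 input; DetectorActivity crops via a
      // transform matrix, but a plain scale is enough for a sketch.
      final Bitmap cropped = Bitmap.createScaledBitmap(
          frame, TF_OD_API_INPUT_SIZE, TF_OD_API_INPUT_SIZE, true);
      final List<Classifier.Recognition> results = detector.recognizeImage(cropped);
      for (final Classifier.Recognition r : results) {
        if (r.getConfidence() >= 0.6f) { // MINIMUM_CONFIDENCE_TF_OD_API
          // Placeholder: draw or track r.getLocation() here.
        }
      }
      detector.close();
    } catch (final IOException e) {
      // As in the patch: log the failure and give up rather than crash.
    }
  }
}
```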
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
index 91d1f9feb1..aae0a4b62a 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
@@ -59,7 +59,10 @@ public class MultiBoxTracker {
private static final float MIN_CORRELATION = 0.3f;
private static final int[] COLORS = {
- Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA
+ Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE,
+ Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"),
+ Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"),
+ Color.parseColor("#AA33AA"), Color.parseColor("#0D0068")
};
private final Queue<Integer> availableColors = new LinkedList<Integer>();
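The widened palette above is what seeds the availableColors queue, one color per tracked object. A rough sketch of that acquire/release pattern (an illustration only, not the tracker's actual logic, which this hunk shows just the edge of):

```java
import java.util.LinkedList;
import java.util.Queue;

public final class ColorPool {
  private final Queue<Integer> available = new LinkedList<Integer>();

  ColorPool(int[] palette) {
    for (int color : palette) {
      available.add(color);
    }
  }

  // Take a color for a newly tracked object; null once all 15 are in use.
  Integer acquire() {
    return available.poll();
  }

  // Recycle a color when its tracked object is dropped.
  void release(int color) {
    available.add(color);
  }
}
```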