diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-09 18:18:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-09 18:22:04 -0700 |
commit | 53aabd5cb0ffcc1fd33cbd00eb468dd8d8353df2 (patch) | |
tree | a1fd42cddba6b0b82e2eb8b121f7e2ab31c08204 | |
parent | 22730fd4c633a74e59c03ff76dc92e6ae2d5d020 (diff) |
Update Android Detect demo to use models exported using the Tensorflow Object Detection API. Resolves #6738.
PiperOrigin-RevId: 164802542
9 files changed, 292 insertions, 23 deletions
@@ -52,6 +52,16 @@ new_http_archive( ) new_http_archive( + name = "mobile_ssd", + build_file = "models.BUILD", + sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", + urls = [ + "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", + "http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", + ], +) + +new_http_archive( name = "mobile_multibox", build_file = "models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h index 973ad4544a..030c00cb8e 100644 --- a/tensorflow/core/framework/register_types.h +++ b/tensorflow/core/framework/register_types.h @@ -95,7 +95,7 @@ limitations under the License. #define TF_CALL_resource(m) #define TF_CALL_complex64(m) #define TF_CALL_int64(m) m(::tensorflow::int64) -#define TF_CALL_bool(m) +#define TF_CALL_bool(m) m(bool) #define TF_CALL_qint8(m) m(::tensorflow::qint8) #define TF_CALL_quint8(m) m(::tensorflow::quint8) @@ -122,7 +122,7 @@ limitations under the License. #define TF_CALL_resource(m) #define TF_CALL_complex64(m) #define TF_CALL_int64(m) -#define TF_CALL_bool(m) +#define TF_CALL_bool(m) m(bool) #define TF_CALL_qint8(m) #define TF_CALL_quint8(m) diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD index 71c16e2399..2d3b0911fc 100644 --- a/tensorflow/examples/android/BUILD +++ b/tensorflow/examples/android/BUILD @@ -92,7 +92,7 @@ filegroup( name = "external_assets", srcs = [ "@inception5h//:model_files", - "@mobile_multibox//:model_files", + "@mobile_ssd//:model_files", "@stylize//:model_files", ], ) diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md index 51f6c4a71c..f9881287cd 100644 --- a/tensorflow/examples/android/README.md +++ b/tensorflow/examples/android/README.md @@ -24,12 +24,14 @@ on API >= 14 devices. model to classify camera frames in real-time, displaying the top results in an overlay on the camera image. 2. [TF Detect](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java): - Demonstrates a model based on [Scalable Object Detection - using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to - localize and track people in the camera preview in real-time. + Demonstrates an SSD-Mobilenet model trained using the + [Tensorflow Object Detection API](https://github.com/tensorflow/models/tree/master/object_detection/) + introduced in [Speed/accuracy trade-offs for modern convolutional object detectors](https://arxiv.org/abs/1611.10012) to + localize and track objects (from 80 categories) in the camera preview + in real-time. 3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java): Uses a model based on [A Learned Representation For Artistic - Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview + Style](https://arxiv.org/abs/1610.07629) to restyle the camera preview image to that of a number of different artists. <img src="sample_images/classify1.jpg" width="30%"><img src="sample_images/stylize1.jpg" width="30%"><img src="sample_images/detect1.jpg" width="30%"> @@ -149,7 +151,7 @@ and extract the archives yourself to the `assets` directory in the source tree: ```bash BASE_URL=https://storage.googleapis.com/download.tensorflow.org/models -for MODEL_ZIP in inception5h.zip mobile_multibox_v1a.zip stylize_v1.zip +for MODEL_ZIP in inception5h.zip ssd_mobilenet_v1_android_export.zip stylize_v1.zip do curl -L ${BASE_URL}/${MODEL_ZIP} -o /tmp/${MODEL_ZIP} unzip /tmp/${MODEL_ZIP} -d tensorflow/examples/android/assets/ diff --git a/tensorflow/examples/android/download-models.gradle b/tensorflow/examples/android/download-models.gradle index aca015fa26..f1750ffddb 100644 --- a/tensorflow/examples/android/download-models.gradle +++ b/tensorflow/examples/android/download-models.gradle @@ -10,6 +10,7 @@ // hard coded model files // LINT.IfChange def models = ['inception5h.zip', + 'ssd_mobilenet_v1_android_export.zip', 'mobile_multibox_v1a.zip', 'stylize_v1.zip'] // LINT.ThenChange(//tensorflow/examples/android/BUILD) diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java index acace0eace..b2a22f3963 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -34,6 +34,8 @@ import android.os.Trace; import android.util.Size; import android.util.TypedValue; import android.view.Display; +import android.widget.Toast; +import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Vector; @@ -62,6 +64,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable private static final String MB_LOCATION_FILE = "file:///android_asset/multibox_location_priors.txt"; + private static final int TF_OD_API_INPUT_SIZE = 300; + private static final String TF_OD_API_MODEL_FILE = + "file:///android_asset/ssd_mobilenet_v1_android_export.pb"; + private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt"; + // Configuration values for tiny-yolo-voc. Note that the graph is not included with TensorFlow and // must be manually placed in the assets/ directory by the user. // Graphs and models downloaded from http://pjreddie.com/darknet/yolo/ may be converted e.g. via @@ -73,15 +80,20 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable private static final String YOLO_OUTPUT_NAMES = "output"; private static final int YOLO_BLOCK_SIZE = 32; - // Default to the included multibox model. - private static final boolean USE_YOLO = false; - - private static final int CROP_SIZE = USE_YOLO ? YOLO_INPUT_SIZE : MB_INPUT_SIZE; + // Which detection model to use: by default uses Tensorflow Object Detection API frozen + // checkpoints. Optionally use legacy Multibox (trained using an older version of the API) + // or YOLO. + private enum DetectorMode { + TF_OD_API, MULTIBOX, YOLO; + } + private static final DetectorMode MODE = DetectorMode.TF_OD_API; // Minimum detection confidence to track a detection. - private static final float MINIMUM_CONFIDENCE = USE_YOLO ? 0.25f : 0.1f; + private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.6f; + private static final float MINIMUM_CONFIDENCE_MULTIBOX = 0.1f; + private static final float MINIMUM_CONFIDENCE_YOLO = 0.25f; - private static final boolean MAINTAIN_ASPECT = USE_YOLO; + private static final boolean MAINTAIN_ASPECT = MODE == DetectorMode.YOLO; private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480); @@ -126,8 +138,8 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable tracker = new MultiBoxTracker(this); - - if (USE_YOLO) { + int cropSize = TF_OD_API_INPUT_SIZE; + if (MODE == DetectorMode.YOLO) { detector = TensorFlowYoloDetector.create( getAssets(), @@ -136,7 +148,8 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable YOLO_INPUT_NAME, YOLO_OUTPUT_NAMES, YOLO_BLOCK_SIZE); - } else { + cropSize = YOLO_INPUT_SIZE; + } else if (MODE == DetectorMode.MULTIBOX) { detector = TensorFlowMultiBoxDetector.create( getAssets(), @@ -147,6 +160,20 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable MB_INPUT_NAME, MB_OUTPUT_LOCATIONS_NAME, MB_OUTPUT_SCORES_NAME); + cropSize = MB_INPUT_SIZE; + } else { + try { + detector = TensorFlowObjectDetectionAPIModel.create( + getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE); + cropSize = TF_OD_API_INPUT_SIZE; + } catch (final IOException e) { + LOGGER.e("Exception initializing classifier!", e); + Toast toast = + Toast.makeText( + getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT); + toast.show(); + finish(); + } } previewWidth = size.getWidth(); @@ -162,12 +189,12 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); rgbBytes = new int[previewWidth * previewHeight]; rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); - croppedBitmap = Bitmap.createBitmap(CROP_SIZE, CROP_SIZE, Config.ARGB_8888); + croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888); frameToCropTransform = ImageUtils.getTransformationMatrix( previewWidth, previewHeight, - CROP_SIZE, CROP_SIZE, + cropSize, cropSize, sensorOrientation, MAINTAIN_ASPECT); cropToFrameTransform = new Matrix(); @@ -322,12 +349,19 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable paint.setStyle(Style.STROKE); paint.setStrokeWidth(2.0f); + float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; + switch (MODE) { + case TF_OD_API: minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API; break; + case MULTIBOX: minimumConfidence = MINIMUM_CONFIDENCE_MULTIBOX; break; + case YOLO: minimumConfidence = MINIMUM_CONFIDENCE_YOLO; break; + } + final List<Classifier.Recognition> mappedRecognitions = new LinkedList<Classifier.Recognition>(); for (final Classifier.Recognition result : results) { final RectF location = result.getLocation(); - if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) { + if (location != null && result.getConfidence() >= minimumConfidence) { canvas.drawRect(location, paint); cropToFrameTransform.mapRect(location); @@ -347,7 +381,7 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable Trace.endSection(); } - protected void processImageRGBbytes(int[] rgbBytes ) {} + protected void processImageRGBbytes(int[] rgbBytes ) {} @Override protected int getLayoutId() { diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index b4a231ff17..ea837e62d5 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -41,7 +41,7 @@ import org.tensorflow.demo.env.Logger; public class TensorFlowMultiBoxDetector implements Classifier { private static final Logger LOGGER = new Logger(); - // Only return this many results with at least this confidence. + // Only return this many results. private static final int MAX_RESULTS = Integer.MAX_VALUE; // Config values. diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java new file mode 100644 index 0000000000..687318c7ce --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowObjectDetectionAPIModel.java @@ -0,0 +1,219 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.demo; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.RectF; +import android.os.Trace; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; +import java.util.Vector; +import org.tensorflow.Graph; +import org.tensorflow.Operation; +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; +import org.tensorflow.demo.env.Logger; + +/** + * Wrapper for frozen detection models trained using the Tensorflow Object Detection API: + * github.com/tensorflow/models/tree/master/object_detection + */ +public class TensorFlowObjectDetectionAPIModel implements Classifier { + private static final Logger LOGGER = new Logger(); + + // Only return this many results. + private static final int MAX_RESULTS = 100; + + // Config values. + private String inputName; + private int inputSize; + + // Pre-allocated buffers. + private Vector<String> labels = new Vector<String>(); + private int[] intValues; + private byte[] byteValues; + private float[] outputLocations; + private float[] outputScores; + private float[] outputClasses; + private float[] outputNumDetections; + private String[] outputNames; + + private boolean logStats = false; + + private TensorFlowInferenceInterface inferenceInterface; + + /** + * Initializes a native TensorFlow session for classifying images. + * + * @param assetManager The asset manager to be used to load assets. + * @param modelFilename The filepath of the model GraphDef protocol buffer. + * @param labelFilename The filepath of label file for classes. + */ + public static Classifier create( + final AssetManager assetManager, + final String modelFilename, + final String labelFilename, + final int inputSize) throws IOException { + final TensorFlowObjectDetectionAPIModel d = new TensorFlowObjectDetectionAPIModel(); + + InputStream labelsInput = null; + String actualFilename = labelFilename.split("file:///android_asset/")[1]; + labelsInput = assetManager.open(actualFilename); + BufferedReader br = null; + br = new BufferedReader(new InputStreamReader(labelsInput)); + String line; + while ((line = br.readLine()) != null) { + LOGGER.w(line); + d.labels.add(line); + } + br.close(); + + + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); + + final Graph g = d.inferenceInterface.graph(); + + d.inputName = "image_tensor"; + // The inputName node has a shape of [N, H, W, C], where + // N is the batch size + // H = W are the height and width + // C is the number of channels (3 for our purposes - RGB) + final Operation inputOp = g.operation(d.inputName); + if (inputOp == null) { + throw new RuntimeException("Failed to find input Node '" + d.inputName + "'"); + } + d.inputSize = inputSize; + // The outputScoresName node has a shape of [N, NumLocations], where N + // is the batch size. + final Operation outputOp1 = g.operation("detection_scores"); + if (outputOp1 == null) { + throw new RuntimeException("Failed to find output Node 'detection_scores'"); + } + final Operation outputOp2 = g.operation("detection_boxes"); + if (outputOp2 == null) { + throw new RuntimeException("Failed to find output Node 'detection_boxes'"); + } + final Operation outputOp3 = g.operation("detection_classes"); + if (outputOp3 == null) { + throw new RuntimeException("Failed to find output Node 'detection_classes'"); + } + + // Pre-allocate buffers. + d.outputNames = new String[] {"detection_boxes", "detection_scores", + "detection_classes", "num_detections"}; + d.intValues = new int[d.inputSize * d.inputSize]; + d.byteValues = new byte[d.inputSize * d.inputSize * 3]; + d.outputScores = new float[MAX_RESULTS]; + d.outputLocations = new float[MAX_RESULTS * 4]; + d.outputClasses = new float[MAX_RESULTS]; + d.outputNumDetections = new float[1]; + return d; + } + + private TensorFlowObjectDetectionAPIModel() {} + + @Override + public List<Recognition> recognizeImage(final Bitmap bitmap) { + // Log this method so that it can be analyzed with systrace. + Trace.beginSection("recognizeImage"); + + Trace.beginSection("preprocessBitmap"); + // Preprocess the image data from 0-255 int to normalized float based + // on the provided parameters. + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + + for (int i = 0; i < intValues.length; ++i) { + byteValues[i * 3 + 2] = (byte) (intValues[i] & 0xFF); + byteValues[i * 3 + 1] = (byte) ((intValues[i] >> 8) & 0xFF); + byteValues[i * 3 + 0] = (byte) ((intValues[i] >> 16) & 0xFF); + } + Trace.endSection(); // preprocessBitmap + + // Copy the input data into TensorFlow. + Trace.beginSection("feed"); + inferenceInterface.feed(inputName, byteValues, 1, inputSize, inputSize, 3); + Trace.endSection(); + + // Run the inference call. + Trace.beginSection("run"); + inferenceInterface.run(outputNames, logStats); + Trace.endSection(); + + // Copy the output Tensor back into the output array. + Trace.beginSection("fetch"); + outputLocations = new float[MAX_RESULTS * 4]; + outputScores = new float[MAX_RESULTS]; + outputClasses = new float[MAX_RESULTS]; + outputNumDetections = new float[1]; + inferenceInterface.fetch(outputNames[0], outputLocations); + inferenceInterface.fetch(outputNames[1], outputScores); + inferenceInterface.fetch(outputNames[2], outputClasses); + inferenceInterface.fetch(outputNames[3], outputNumDetections); + Trace.endSection(); + + // Find the best detections. + final PriorityQueue<Recognition> pq = + new PriorityQueue<Recognition>( + 1, + new Comparator<Recognition>() { + @Override + public int compare(final Recognition lhs, final Recognition rhs) { + // Intentionally reversed to put high confidence at the head of the queue. + return Float.compare(rhs.getConfidence(), lhs.getConfidence()); + } + }); + + // Scale them back to the input size. + for (int i = 0; i < outputScores.length; ++i) { + final RectF detection = + new RectF( + outputLocations[4 * i + 1] * inputSize, + outputLocations[4 * i] * inputSize, + outputLocations[4 * i + 3] * inputSize, + outputLocations[4 * i + 2] * inputSize); + pq.add( + new Recognition("" + i, labels.get((int) outputClasses[i]), outputScores[i], detection)); + } + + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); + for (int i = 0; i < Math.min(pq.size(), MAX_RESULTS); ++i) { + recognitions.add(pq.poll()); + } + Trace.endSection(); // "recognizeImage" + return recognitions; + } + + @Override + public void enableStatLogging(final boolean logStats) { + this.logStats = logStats; + } + + @Override + public String getStatString() { + return inferenceInterface.getStatString(); + } + + @Override + public void close() { + inferenceInterface.close(); + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java index 91d1f9feb1..aae0a4b62a 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java @@ -59,7 +59,10 @@ public class MultiBoxTracker { private static final float MIN_CORRELATION = 0.3f; private static final int[] COLORS = { - Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA + Color.BLUE, Color.RED, Color.GREEN, Color.YELLOW, Color.CYAN, Color.MAGENTA, Color.WHITE, + Color.parseColor("#55FF55"), Color.parseColor("#FFA500"), Color.parseColor("#FF8888"), + Color.parseColor("#AAAAFF"), Color.parseColor("#FFFFAA"), Color.parseColor("#55AAAA"), + Color.parseColor("#AA33AA"), Color.parseColor("#0D0068") }; private final Queue<Integer> availableColors = new LinkedList<Integer>(); |