diff options
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java')
-rw-r--r-- | tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java | 317 |
1 files changed, 317 insertions, 0 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java new file mode 100644 index 0000000000..d75136485a --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/DetectorActivity.java @@ -0,0 +1,317 @@ +/* + * Copyright 2016 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.demo; + +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import android.graphics.Paint.Style; +import android.graphics.RectF; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; +import android.os.SystemClock; +import android.os.Trace; +import android.util.Size; +import android.util.TypedValue; +import android.view.Display; +import java.io.IOException; +import java.util.LinkedList; +import java.util.List; +import java.util.Vector; +import org.tensorflow.demo.OverlayView.DrawCallback; +import org.tensorflow.demo.env.BorderedText; +import org.tensorflow.demo.env.ImageUtils; +import org.tensorflow.demo.env.Logger; +import org.tensorflow.demo.tracking.MultiBoxTracker; + +/** + * An activity that uses a TensorFlowMultiboxDetector and ObjectTracker to detect and then track + * objects. + */ +public class DetectorActivity extends CameraActivity implements OnImageAvailableListener { + private static final Logger LOGGER = new Logger(); + + private static final int NUM_LOCATIONS = 784; + private static final int INPUT_SIZE = 224; + private static final int IMAGE_MEAN = 128; + private static final float IMAGE_STD = 128; + private static final String INPUT_NAME = "ResizeBilinear"; + private static final String OUTPUT_NAMES = "output_locations/Reshape,output_scores/Reshape"; + + private static final String MODEL_FILE = "file:///android_asset/multibox_model.pb"; + private static final String LOCATION_FILE = "file:///android_asset/multibox_location_priors.pb"; + + // Minimum detection confidence to track a detection. + private static final float MINIMUM_CONFIDENCE = 0.1f; + + private static final boolean SAVE_PREVIEW_BITMAP = false; + + private static final boolean MAINTAIN_ASPECT = false; + + private static final float TEXT_SIZE_DIP = 18; + + private Integer sensorOrientation; + + private TensorFlowMultiBoxDetector detector; + + private int previewWidth = 0; + private int previewHeight = 0; + private byte[][] yuvBytes; + private int[] rgbBytes = null; + private Bitmap rgbFrameBitmap = null; + private Bitmap croppedBitmap = null; + + private boolean computing = false; + + private long timestamp = 0; + + private Matrix frameToCropTransform; + private Matrix cropToFrameTransform; + + private Bitmap cropCopyBitmap; + + private MultiBoxTracker tracker; + + private byte[] luminance; + + private BorderedText borderedText; + + private long lastProcessingTimeMs; + + @Override + public void onPreviewSizeChosen(final Size size, final int rotation) { + final float textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + borderedText = new BorderedText(textSizePx); + + tracker = new MultiBoxTracker(getResources().getDisplayMetrics()); + + detector = new TensorFlowMultiBoxDetector(); + try { + detector.initializeTensorFlow( + getAssets(), + MODEL_FILE, + LOCATION_FILE, + NUM_LOCATIONS, + INPUT_SIZE, + IMAGE_MEAN, + IMAGE_STD, + INPUT_NAME, + OUTPUT_NAMES); + } catch (final IOException e) { + LOGGER.e(e, "Exception!"); + } + + previewWidth = size.getWidth(); + previewHeight = size.getHeight(); + + final Display display = getWindowManager().getDefaultDisplay(); + final int screenOrientation = display.getRotation(); + + LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation); + + sensorOrientation = rotation + screenOrientation; + + LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight); + rgbBytes = new int[previewWidth * previewHeight]; + rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888); + croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888); + + frameToCropTransform = + ImageUtils.getTransformationMatrix( + previewWidth, previewHeight, + INPUT_SIZE, INPUT_SIZE, + sensorOrientation, MAINTAIN_ASPECT); + + cropToFrameTransform = new Matrix(); + frameToCropTransform.invert(cropToFrameTransform); + yuvBytes = new byte[3][]; + + addCallback( + new DrawCallback() { + @Override + public void drawCallback(final Canvas canvas) { + final Bitmap copy = cropCopyBitmap; + + tracker.draw(canvas); + + if (!isDebug()) { + return; + } + + tracker.drawDebug(canvas); + + if (copy != null) { + final Matrix matrix = new Matrix(); + final float scaleFactor = 2; + matrix.postScale(scaleFactor, scaleFactor); + matrix.postTranslate( + canvas.getWidth() - copy.getWidth() * scaleFactor, + canvas.getHeight() - copy.getHeight() * scaleFactor); + canvas.drawBitmap(copy, matrix, new Paint()); + + final Vector<String> lines = new Vector<String>(); + lines.add("Frame: " + previewWidth + "x" + previewHeight); + lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight()); + lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight()); + lines.add("Rotation: " + sensorOrientation); + lines.add("Inference time: " + lastProcessingTimeMs + "ms"); + + int lineNum = 0; + for (final String line : lines) { + borderedText.drawText( + canvas, + 10, + canvas.getHeight() - 10 - borderedText.getTextSize() * lineNum, + line); + ++lineNum; + } + } + } + }); + } + + @Override + public void onImageAvailable(final ImageReader reader) { + Image image = null; + + ++timestamp; + final long currTimestamp = timestamp; + + try { + image = reader.acquireLatestImage(); + + if (image == null) { + return; + } + + Trace.beginSection("imageAvailable"); + + final Plane[] planes = image.getPlanes(); + fillBytes(planes, yuvBytes); + + tracker.onFrame( + previewWidth, + previewHeight, + planes[0].getRowStride(), + sensorOrientation, + yuvBytes[0], + timestamp); + + requestRender(); + + // No mutex needed as this method is not reentrant. + if (computing) { + image.close(); + return; + } + computing = true; + + final int yRowStride = planes[0].getRowStride(); + final int uvRowStride = planes[1].getRowStride(); + final int uvPixelStride = planes[1].getPixelStride(); + ImageUtils.convertYUV420ToARGB8888( + yuvBytes[0], + yuvBytes[1], + yuvBytes[2], + rgbBytes, + previewWidth, + previewHeight, + yRowStride, + uvRowStride, + uvPixelStride, + false); + + image.close(); + } catch (final Exception e) { + if (image != null) { + image.close(); + } + LOGGER.e(e, "Exception!"); + Trace.endSection(); + return; + } + + rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight); + final Canvas canvas = new Canvas(croppedBitmap); + canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null); + + // For examining the actual TF input. + if (SAVE_PREVIEW_BITMAP) { + ImageUtils.saveBitmap(croppedBitmap); + } + + if (luminance == null) { + luminance = new byte[yuvBytes[0].length]; + } + System.arraycopy(yuvBytes[0], 0, luminance, 0, luminance.length); + + runInBackground( + new Runnable() { + @Override + public void run() { + final long startTime = SystemClock.uptimeMillis(); + final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap); + lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime; + + cropCopyBitmap = Bitmap.createBitmap(croppedBitmap); + final Canvas canvas = new Canvas(cropCopyBitmap); + final Paint paint = new Paint(); + paint.setColor(Color.RED); + paint.setStyle(Style.STROKE); + paint.setStrokeWidth(2.0f); + + final List<Classifier.Recognition> mappedRecognitions = + new LinkedList<Classifier.Recognition>(); + + for (final Classifier.Recognition result : results) { + final RectF location = result.getLocation(); + if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE) { + canvas.drawRect(location, paint); + + cropToFrameTransform.mapRect(location); + result.setLocation(location); + mappedRecognitions.add(result); + } + } + + tracker.trackResults(mappedRecognitions, luminance, currTimestamp); + + requestRender(); + computing = false; + } + }); + + Trace.endSection(); + } + + @Override + protected int getLayoutId() { + return R.layout.camera_connection_fragment_tracking; + } + + @Override + protected int getDesiredPreviewFrameSize() { + return INPUT_SIZE; + } +} |