diff options
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java')
-rw-r--r-- | tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java | 147 |
1 files changed, 147 insertions, 0 deletions
package org.tensorflow.demo;

import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Matrix;
import android.media.Image;
import android.media.Image.Plane;
import android.media.ImageReader;
import android.media.ImageReader.OnImageAvailableListener;

import org.tensorflow.demo.env.ImageUtils;
import org.tensorflow.demo.env.Logger;

import java.nio.ByteBuffer;
import java.util.List;

/**
 * Class that takes in preview frames and converts the image to Bitmaps to process with Tensorflow.
 *
 * <p>Each frame delivered by the {@link ImageReader} is copied out of its YUV planes, converted to
 * an ARGB {@link Bitmap}, center-cropped/scaled/rotated to the network input size, and fed to the
 * classifier. Results are published to a {@link RecognitionScoreView}.
 */
public class TensorflowImageListener implements OnImageAvailableListener {
  private static final Logger LOGGER = new Logger();

  // When true, each cropped network-input bitmap is dumped to disk for debugging.
  private static final boolean SAVE_PREVIEW_BITMAP = false;

  private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
  private static final String LABEL_FILE =
      "file:///android_asset/imagenet_comp_graph_label_strings.txt";

  private static final int NUM_CLASSES = 1001;
  private static final int INPUT_SIZE = 224;
  private static final int IMAGE_MEAN = 117;

  // TODO(andrewharp): Get orientation programmatically.
  private final int screenRotation = 90;

  private final TensorflowClassifier tensorflow = new TensorflowClassifier();

  // Preview geometry and per-frame scratch buffers; (re)allocated lazily in
  // onImageAvailable() once the actual frame resolution is known.
  private int previewWidth = 0;
  private int previewHeight = 0;
  private byte[] yuvBytes = null;
  private int[] rgbBytes = null;
  private Bitmap rgbFrameBitmap = null;
  private Bitmap croppedBitmap = null;

  private RecognitionScoreView scoreView;

  /**
   * Loads the Tensorflow model/labels from assets and wires up the view that will
   * display recognition results. Must be called before frames are delivered.
   *
   * @param assetManager source of the model and label files
   * @param scoreView view updated with recognition results for each frame
   */
  public void initialize(final AssetManager assetManager, final RecognitionScoreView scoreView) {
    tensorflow.initializeTensorflow(
        assetManager, MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN);
    this.scoreView = scoreView;
  }

  /**
   * Draws the center square of {@code src} into the (square) {@code dst}, scaled to fit and
   * rotated by {@code screenRotation} about the destination's center.
   *
   * @param src full preview frame
   * @param dst square destination bitmap (the network input)
   * @throws IllegalArgumentException if {@code dst} is not square
   */
  private void drawResizedBitmap(final Bitmap src, final Bitmap dst) {
    // Was junit.framework.Assert.assertEquals — junit does not belong in
    // production code; validate explicitly instead.
    if (dst.getWidth() != dst.getHeight()) {
      throw new IllegalArgumentException(
          "Destination bitmap must be square, got " + dst.getWidth() + "x" + dst.getHeight());
    }
    final float minDim = Math.min(src.getWidth(), src.getHeight());

    final Matrix matrix = new Matrix();

    // We only want the center square out of the original rectangle.
    final float translateX = -Math.max(0, (src.getWidth() - minDim) / 2);
    final float translateY = -Math.max(0, (src.getHeight() - minDim) / 2);
    matrix.preTranslate(translateX, translateY);

    final float scaleFactor = dst.getHeight() / minDim;
    matrix.postScale(scaleFactor, scaleFactor);

    // Rotate around the center if necessary.
    if (screenRotation != 0) {
      matrix.postTranslate(-dst.getWidth() / 2.0f, -dst.getHeight() / 2.0f);
      matrix.postRotate(screenRotation);
      matrix.postTranslate(dst.getWidth() / 2.0f, dst.getHeight() / 2.0f);
    }

    final Canvas canvas = new Canvas(dst);
    canvas.drawBitmap(src, matrix, null);
  }

  @Override
  public void onImageAvailable(final ImageReader reader) {
    Image image = null;
    try {
      image = reader.acquireLatestImage();

      if (image == null) {
        return;
      }

      // Initialize the storage bitmaps once when the resolution is known.
      if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) {
        previewWidth = image.getWidth();
        previewHeight = image.getHeight();
        // Log after the fields are assigned so the actual new size is reported
        // (previously this logged the stale values, i.e. 0x0 on the first frame).
        LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
        rgbBytes = new int[previewWidth * previewHeight];
        yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)];
        rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
        croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
      }

      final Plane[] planes = image.getPlanes();
      int position = 0;

      // Copy the bytes from the Image into a buffer for easier conversion to RGB.
      // TODO(andrewharp): It may not be correct to do it this way.
      final int[] planeOrder = {0, 2};
      for (final int planeIndex : planeOrder) {
        final ByteBuffer buffer = planes[planeIndex].getBuffer();

        buffer.rewind();
        final int readAmount = buffer.remaining();

        buffer.get(yuvBytes, position, readAmount);
        position += readAmount;
      }

      image.close();
      // Null out so the catch block below cannot close the image a second time
      // if the conversion throws.
      image = null;

      ImageUtils.convertYUV420SPToARGB8888(yuvBytes, rgbBytes, previewWidth, previewHeight, false);
    } catch (final Exception e) {
      if (image != null) {
        image.close();
      }
      LOGGER.e(e, "Exception!");
      return;
    }

    rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
    drawResizedBitmap(rgbFrameBitmap, croppedBitmap);

    // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) {
      ImageUtils.saveBitmap(croppedBitmap);
    }

    final List<Classifier.Recognition> results = tensorflow.recognizeImage(croppedBitmap);

    LOGGER.v("%d results", results.size());
    for (final Classifier.Recognition result : results) {
      LOGGER.v("Result: " + result.getTitle());
    }
    scoreView.setResults(results);
  }
}