diff options
Diffstat (limited to 'tensorflow/examples/android/src/org')
9 files changed, 1339 insertions, 0 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java new file mode 100644 index 0000000000..011dc64d16 --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java @@ -0,0 +1,74 @@ +/* + * Copyright 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.demo; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** + * A {@link TextureView} that can be adjusted to a specified aspect ratio. + */ +public class AutoFitTextureView extends TextureView { + private int ratioWidth = 0; + private int ratioHeight = 0; + + public AutoFitTextureView(final Context context) { + this(context, null); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(final Context context, final AttributeSet attrs, final int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that + * is, calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. 
+ * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(final int width, final int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + ratioWidth = width; + ratioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + final int width = MeasureSpec.getSize(widthMeasureSpec); + final int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == ratioWidth || 0 == ratioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * ratioWidth / ratioHeight) { + setMeasuredDimension(width, width * ratioHeight / ratioWidth); + } else { + setMeasuredDimension(height * ratioWidth / ratioHeight, height); + } + } + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java new file mode 100644 index 0000000000..943dddd254 --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/CameraActivity.java @@ -0,0 +1,34 @@ +/* + * Copyright 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.demo; + +import android.app.Activity; +import android.os.Bundle; + +public class CameraActivity extends Activity { + @Override + protected void onCreate(final Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + if (null == savedInstanceState) { + getFragmentManager() + .beginTransaction() + .replace(R.id.container, CameraConnectionFragment.newInstance()) + .commit(); + } + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java new file mode 100644 index 0000000000..d9a696d9bb --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java @@ -0,0 +1,593 @@ +/* + * Copyright 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tensorflow.demo; + +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.res.Configuration; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.util.Size; +import android.util.SparseIntArray; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Toast; + +import org.tensorflow.demo.env.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +public class CameraConnectionFragment extends Fragment { + private static final Logger LOGGER = new Logger(); + + private RecognitionScoreView scoreView; + + /** + * Conversion from screen rotation to JPEG orientation. 
+ */ + private static final SparseIntArray ORIENTATIONS = new SparseIntArray(); + private static final String FRAGMENT_DIALOG = "dialog"; + + static { + ORIENTATIONS.append(Surface.ROTATION_0, 90); + ORIENTATIONS.append(Surface.ROTATION_90, 0); + ORIENTATIONS.append(Surface.ROTATION_180, 270); + ORIENTATIONS.append(Surface.ROTATION_270, 180); + } + + /** + * {@link android.view.TextureView.SurfaceTextureListener} handles several lifecycle events on a + * {@link TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + @Override + public void onSurfaceTextureAvailable( + final SurfaceTexture texture, final int width, final int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged( + final SurfaceTexture texture, final int width, final int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(final SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(final SurfaceTexture texture) {} + }; + + /** + * ID of the current {@link CameraDevice}. + */ + private String cameraId; + + /** + * An {@link AutoFitTextureView} for camera preview. + */ + private AutoFitTextureView textureView; + + /** + * A {@link CameraCaptureSession } for camera preview. + */ + private CameraCaptureSession captureSession; + + /** + * A reference to the opened {@link CameraDevice}. + */ + private CameraDevice cameraDevice; + + /** + * The {@link android.util.Size} of camera preview. + */ + private Size previewSize; + + /** + * {@link android.hardware.camera2.CameraDevice.StateCallback} + * is called when {@link CameraDevice} changes its state. + */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + @Override + public void onOpened(final CameraDevice cd) { + // This method is called when the camera is opened. 
We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = cd; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(final CameraDevice cd) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + } + + @Override + public void onError(final CameraDevice cd, final int error) { + cameraOpenCloseLock.release(); + cd.close(); + cameraDevice = null; + final Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + /** + * An additional thread for running tasks that shouldn't block the UI. + */ + private HandlerThread backgroundThread; + + /** + * A {@link Handler} for running tasks in the background. + */ + private Handler backgroundHandler; + + /** + * An {@link ImageReader} that handles still image capture. + */ + private ImageReader imageReader; + + /** + * {@link android.hardware.camera2.CaptureRequest.Builder} for the camera preview + */ + private CaptureRequest.Builder previewRequestBuilder; + + /** + * {@link CaptureRequest} generated by {@link #previewRequestBuilder} + */ + private CaptureRequest previewRequest; + + /** + * A {@link Semaphore} to prevent the app from exiting before closing the camera. + */ + private final Semaphore cameraOpenCloseLock = new Semaphore(1); + + /** + * Shows a {@link Toast} on the UI thread. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + Toast.makeText(activity, text, Toast.LENGTH_SHORT).show(); + } + }); + } + } + + /** + * Given {@code choices} of {@code Size}s supported by a camera, chooses the smallest one whose + * width and height are at least as large as the respective requested values, and whose aspect + * ratio matches with the specified value. 
+ * + * @param choices The list of sizes that the camera supports for the intended output class + * @param width The minimum desired width + * @param height The minimum desired height + * @param aspectRatio The aspect ratio + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + private static Size chooseOptimalSize( + final Size[] choices, final int width, final int height, final Size aspectRatio) { + // Collect the supported resolutions that are at least as big as the preview Surface + final List<Size> bigEnough = new ArrayList<>(); + for (final Size option : choices) { + // TODO(andrewharp): Choose size intelligently. + if (option.getHeight() == 320 && option.getWidth() == 480) { + LOGGER.i("Adding size: " + option.getWidth() + "x" + option.getHeight()); + bigEnough.add(option); + } else { + LOGGER.i("Not adding size: " + option.getWidth() + "x" + option.getHeight()); + } + } + + // Pick the smallest of those, assuming we found any + if (bigEnough.size() > 0) { + final Size chosenSize = Collections.min(bigEnough, new CompareSizesByArea()); + LOGGER.i("Chosen size: " + chosenSize.getWidth() + "x" + chosenSize.getHeight()); + return chosenSize; + } else { + LOGGER.e("Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static CameraConnectionFragment newInstance() { + return new CameraConnectionFragment(); + } + + @Override + public View onCreateView( + final LayoutInflater inflater, final ViewGroup container, final Bundle savedInstanceState) { + return inflater.inflate(R.layout.camera_connection_fragment, container, false); + } + + @Override + public void onViewCreated(final View view, final Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + scoreView = (RecognitionScoreView) view.findViewById(R.id.results); + } + + @Override + public void onActivityCreated(final Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + } + + 
@Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + /** + * Sets up member variables related to camera. + * + * @param width The width of available size for camera preview + * @param height The height of available size for camera preview + */ + private void setUpCameraOutputs(final int width, final int height) { + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + for (final String cameraId : manager.getCameraIdList()) { + final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + final Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + final StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + + if (map == null) { + continue; + } + + // For still image captures, we use the largest available size. + final Size largest = + Collections.max( + Arrays.asList(map.getOutputSizes(ImageFormat.YUV_420_888)), + new CompareSizesByArea()); + + imageReader = + ImageReader.newInstance( + largest.getWidth(), largest.getHeight(), ImageFormat.YUV_420_888, /*maxImages*/ 2); + + // Danger, W.R.! 
Attempting to use too large a preview size could exceed the camera + // bus' bandwidth limitation, resulting in gorgeous previews but the storage of + // garbage capture data. + previewSize = + chooseOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height, largest); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + final int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + + CameraConnectionFragment.this.cameraId = cameraId; + return; + } + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + } + } + + /** + * Opens the camera specified by {@link CameraConnectionFragment#cameraId}. + */ + private void openCamera(final int width, final int height) { + setUpCameraOutputs(width, height); + configureTransform(width, height); + final Activity activity = getActivity(); + final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + /** + * Closes the current {@link CameraDevice}. 
+ */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != imageReader) { + imageReader.close(); + imageReader = null; + } + } catch (final InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** + * Starts a background thread and its {@link Handler}. + */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread("CameraBackground"); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + } + + /** + * Stops the background thread and its {@link Handler}. + */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + } catch (final InterruptedException e) { + LOGGER.e(e, "Exception!"); + } + } + + private final TensorflowImageListener tfPreviewListener = new TensorflowImageListener(); + + private final CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + @Override + public void onCaptureProgressed( + final CameraCaptureSession session, + final CaptureRequest request, + final CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + final CameraCaptureSession session, + final CaptureRequest request, + final TotalCaptureResult result) {} + }; + + /** + * Creates a new {@link CameraCaptureSession} for camera preview. + */ + private void createCameraPreviewSession() { + try { + final SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. 
+ texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + final Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + LOGGER.i("Opening camera preview: " + previewSize.getWidth() + "x" + previewSize.getHeight()); + + // Create the reader for the preview frames. + final ImageReader previewReader = + ImageReader.newInstance( + previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2); + + previewReader.setOnImageAvailableListener(tfPreviewListener, backgroundHandler); + previewRequestBuilder.addTarget(previewReader.getSurface()); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface, imageReader.getSurface(), previewReader.getSurface()), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(final CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + // Flash is automatically enabled when necessary. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH); + + // Finally, we start displaying the camera preview. 
+ previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + } + + @Override + public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (final CameraAccessException e) { + LOGGER.e(e, "Exception!"); + } + + LOGGER.i("Getting assets."); + tfPreviewListener.initialize(getActivity().getAssets(), scoreView); + LOGGER.i("Tensorflow initialized."); + } + + /** + * Configures the necessary {@link android.graphics.Matrix} transformation to `mTextureView`. + * This method should be called after the camera preview size is determined in + * setUpCameraOutputs and also the size of `mTextureView` is fixed. + * + * @param viewWidth The width of `mTextureView` + * @param viewHeight The height of `mTextureView` + */ + private void configureTransform(final int viewWidth, final int viewHeight) { + final Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + final int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + final Matrix matrix = new Matrix(); + final RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + final RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + final float centerX = viewRect.centerX(); + final float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + final float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + 
} else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** + * Compares two {@code Size}s based on their areas. + */ + static class CompareSizesByArea implements Comparator<Size> { + @Override + public int compare(final Size lhs, final Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** + * Shows an error message dialog. + */ + public static class ErrorDialog extends DialogFragment { + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(final String message) { + final ErrorDialog dialog = new ErrorDialog(); + final Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(final Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(final DialogInterface dialogInterface, final int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java new file mode 100644 index 0000000000..60b3037c7d --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/Classifier.java @@ -0,0 +1,87 @@ +package org.tensorflow.demo; + +import android.graphics.Bitmap; +import android.graphics.RectF; + +import java.util.List; + +/** + * Generic interface for interacting with different recognition engines. + */ +public interface Classifier { + /** + * An immutable result returned by a Classifier describing what was recognized. 
+ */ + public class Recognition { + /** + * A unique identifier for what has been recognized. Specific to the class, not the instance of + * the object. + */ + private final String id; + + /** + * Display name for the recognition. + */ + private final String title; + + /** + * A sortable score for how good the recognition is relative to others. Higher should be better. + */ + private final Float confidence; + + /** + * Optional location within the source image for the location of the recognized object. + */ + private final RectF location; + + public Recognition( + final String id, final String title, final Float confidence, final RectF location) { + this.id = id; + this.title = title; + this.confidence = confidence; + this.location = location; + } + + public String getId() { + return id; + } + + public String getTitle() { + return title; + } + + public Float getConfidence() { + return confidence; + } + + public RectF getLocation() { + return new RectF(location); + } + + @Override + public String toString() { + String resultString = ""; + if (id != null) { + resultString += "[" + id + "] "; + } + + if (title != null) { + resultString += title + " "; + } + + if (confidence != null) { + resultString += String.format("(%.1f%%) ", confidence * 100.0f); + } + + if (location != null) { + resultString += location + " "; + } + + return resultString.trim(); + } + } + + List<Recognition> recognizeImage(Bitmap bitmap); + + void close(); +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java new file mode 100644 index 0000000000..961b492a8d --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java @@ -0,0 +1,53 @@ +package org.tensorflow.demo; + +import android.content.Context; +import android.graphics.Canvas; +import android.graphics.Paint; +import android.util.AttributeSet; +import android.util.TypedValue; +import 
android.view.View; + +import org.tensorflow.demo.Classifier.Recognition; + +import java.util.List; + +public class RecognitionScoreView extends View { + private static final float TEXT_SIZE_DIP = 24; + private List<Recognition> results; + private final float textSizePx; + private final Paint fgPaint; + private final Paint bgPaint; + + public RecognitionScoreView(final Context context, final AttributeSet set) { + super(context, set); + + textSizePx = + TypedValue.applyDimension( + TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()); + fgPaint = new Paint(); + fgPaint.setTextSize(textSizePx); + + bgPaint = new Paint(); + bgPaint.setColor(0xcc4285f4); + } + + public void setResults(final List<Recognition> results) { + this.results = results; + postInvalidate(); + } + + @Override + public void onDraw(final Canvas canvas) { + final int x = 10; + int y = (int) (fgPaint.getTextSize() * 1.5f); + + canvas.drawPaint(bgPaint); + + if (results != null) { + for (final Recognition recog : results) { + canvas.drawText(recog.getTitle() + ": " + recog.getConfidence(), x, y, fgPaint); + y += fgPaint.getTextSize() * 1.5f; + } + } + } +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java new file mode 100644 index 0000000000..84a7596ecb --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowClassifier.java @@ -0,0 +1,62 @@ +package org.tensorflow.demo; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.util.Log; + +import java.util.ArrayList; +import java.util.List; +import java.util.StringTokenizer; + +/** + * JNI wrapper class for the Tensorflow native code. + */ +public class TensorflowClassifier implements Classifier { + private static final String TAG = "TensorflowClassifier"; + + // jni native methods. 
+ public native int initializeTensorflow( + AssetManager assetManager, + String model, + String labels, + int numClasses, + int inputSize, + int imageMean); + + private native String classifyImageBmp(Bitmap bitmap); + + private native String classifyImageRgb(int[] output, int width, int height); + + static { + System.loadLibrary("tensorflow_demo"); + } + + @Override + public List<Recognition> recognizeImage(final Bitmap bitmap) { + final ArrayList<Recognition> recognitions = new ArrayList<Recognition>(); + for (final String result : classifyImageBmp(bitmap).split("\n")) { + Log.i(TAG, "Parsing [" + result + "]"); + + // Clean up the string as needed + final StringTokenizer st = new StringTokenizer(result); + if (!st.hasMoreTokens()) { + continue; + } + + final String id = st.nextToken(); + final String confidenceString = st.nextToken(); + final float confidence = Float.parseFloat(confidenceString); + + final String title = + result.substring(id.length() + confidenceString.length() + 2, result.length()); + + if (!title.isEmpty()) { + recognitions.add(new Recognition(id, title, confidence, null)); + } + } + return recognitions; + } + + @Override + public void close() {} +} diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java new file mode 100644 index 0000000000..940fbc6771 --- /dev/null +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorflowImageListener.java @@ -0,0 +1,147 @@ +package org.tensorflow.demo; + +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.Bitmap.Config; +import android.graphics.Canvas; +import android.graphics.Matrix; +import android.media.Image; +import android.media.Image.Plane; +import android.media.ImageReader; +import android.media.ImageReader.OnImageAvailableListener; + +import junit.framework.Assert; + +import org.tensorflow.demo.env.ImageUtils; +import 
org.tensorflow.demo.env.Logger;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+/**
+ * Class that takes in preview frames and converts the image to Bitmaps to process with Tensorflow.
+ */
+public class TensorflowImageListener implements OnImageAvailableListener {
+ private static final Logger LOGGER = new Logger();
+
+ private static final boolean SAVE_PREVIEW_BITMAP = false;
+
+ // Model and label files are bundled as Android assets and resolved by the native side.
+ private static final String MODEL_FILE = "file:///android_asset/tensorflow_inception_graph.pb";
+ private static final String LABEL_FILE =
+ "file:///android_asset/imagenet_comp_graph_label_strings.txt";
+
+ private static final int NUM_CLASSES = 1001;
+ private static final int INPUT_SIZE = 224;
+ private static final int IMAGE_MEAN = 117;
+
+ // TODO(andrewharp): Get orientation programatically.
+ private final int screenRotation = 90;
+
+ private final TensorflowClassifier tensorflow = new TensorflowClassifier();
+
+ // Frame buffers, lazily (re)allocated in onImageAvailable once the preview size is known.
+ private int previewWidth = 0;
+ private int previewHeight = 0;
+ private byte[] yuvBytes = null;
+ private int[] rgbBytes = null;
+ private Bitmap rgbFrameBitmap = null;
+ private Bitmap croppedBitmap = null;
+
+ private RecognitionScoreView scoreView;
+
+ // One-time setup: initializes the native classifier and keeps the view used to show results.
+ public void initialize(final AssetManager assetManager, final RecognitionScoreView scoreView) {
+ tensorflow.initializeTensorflow(
+ assetManager, MODEL_FILE, LABEL_FILE, NUM_CLASSES, INPUT_SIZE, IMAGE_MEAN);
+ this.scoreView = scoreView;
+ }
+
+ // Center-crops src to a square of side minDim, scales it to fill dst (asserted square), and
+ // rotates by screenRotation about the center of dst.
+ private void drawResizedBitmap(final Bitmap src, final Bitmap dst) {
+ Assert.assertEquals(dst.getWidth(), dst.getHeight());
+ final float minDim = Math.min(src.getWidth(), src.getHeight());
+
+ final Matrix matrix = new Matrix();
+
+ // We only want the center square out of the original rectangle.
+ final float translateX = -Math.max(0, (src.getWidth() - minDim) / 2);
+ final float translateY = -Math.max(0, (src.getHeight() - minDim) / 2);
+ matrix.preTranslate(translateX, translateY);
+
+ final float scaleFactor = dst.getHeight() / minDim;
+ matrix.postScale(scaleFactor, scaleFactor);
+
+ // Rotate around the center if necessary.
+ if (screenRotation != 0) {
+ matrix.postTranslate(-dst.getWidth() / 2.0f, -dst.getHeight() / 2.0f);
+ matrix.postRotate(screenRotation);
+ matrix.postTranslate(dst.getWidth() / 2.0f, dst.getHeight() / 2.0f);
+ }
+
+ final Canvas canvas = new Canvas(dst);
+ canvas.drawBitmap(src, matrix, null);
+ }
+
+ @Override
+ public void onImageAvailable(final ImageReader reader) {
+ Image image = null;
+ try {
+ image = reader.acquireLatestImage();
+
+ if (image == null) {
+ return;
+ }
+
+ // Initialize the storage bitmaps once when the resolution is known.
+ if (previewWidth != image.getWidth() || previewHeight != image.getHeight()) {
+ // NOTE(review): this logs the OLD previewWidth/previewHeight (0x0 on the first frame);
+ // presumably the new image size was intended — log after the assignments below.
+ LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
+ previewWidth = image.getWidth();
+ previewHeight = image.getHeight();
+ rgbBytes = new int[previewWidth * previewHeight];
+ yuvBytes = new byte[ImageUtils.getYUVByteSize(previewWidth, previewHeight)];
+ rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(INPUT_SIZE, INPUT_SIZE, Config.ARGB_8888);
+ }
+
+ final Plane[] planes = image.getPlanes();
+ int position = 0;
+
+ // Copy the bytes from the Image into a buffer for easier conversion to RGB.
+ // TODO(andrewharp): It may not be correct to do it this way.
+ // NOTE(review): reads only planes 0 (Y) and 2 (V), assuming plane 2 is VU-interleaved
+ // (NV21-style) so the U samples come along with it — confirm against the device's actual
+ // YUV_420_888 plane layout (pixelStride/rowStride are not checked here).
+ final int[] planeOrder = {0, 2};
+ for (int i = 0; i < planeOrder.length; ++i) {
+ final Plane plane = planes[planeOrder[i]];
+ final ByteBuffer buffer = plane.getBuffer();
+
+ buffer.rewind();
+ final int readAmount = buffer.remaining();
+
+ buffer.get(yuvBytes, position, readAmount);
+ position += readAmount;
+ }
+
+ image.close();
+
+ ImageUtils.convertYUV420SPToARGB8888(yuvBytes, rgbBytes, previewWidth, previewHeight, false);
+ } catch (final Exception e) {
+ if (image != null) {
+ // NOTE(review): if the exception occurred after the image.close() above, this closes the
+ // Image a second time — verify close() tolerates that, or null the reference after closing.
+ image.close();
+ }
+ LOGGER.e(e, "Exception!");
+ return;
+ }
+
+ rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ drawResizedBitmap(rgbFrameBitmap, croppedBitmap);
+
+ // For examining the actual TF input.
+ if (SAVE_PREVIEW_BITMAP) {
+ ImageUtils.saveBitmap(croppedBitmap);
+ }
+
+ final List<Classifier.Recognition> results = tensorflow.recognizeImage(croppedBitmap);
+
+ LOGGER.v("%d results", results.size());
+ for (final Classifier.Recognition result : results) {
+ LOGGER.v("Result: " + result.getTitle());
+ }
+ scoreView.setResults(results);
+ }
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
new file mode 100644
index 0000000000..78f818f734
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
@@ -0,0 +1,113 @@
+package org.tensorflow.demo.env;
+
+import android.graphics.Bitmap;
+import android.os.Environment;
+
+import java.io.File;
+import java.io.FileOutputStream;
+
+/**
+ * Utility class for manipulating images.
+ **/
+public class ImageUtils {
+ @SuppressWarnings("unused")
+ private static final Logger LOGGER = new Logger();
+
+ /**
+ * Utility method to compute the allocated size in bytes of a YUV420SP image
+ * of the given dimensions.
+ */
+ public static int getYUVByteSize(final int width, final int height) {
+ // The luminance plane requires 1 byte per pixel.
+ final int ySize = width * height;
+
+ // The UV plane works on 2x2 blocks, so dimensions with odd size must be rounded up.
+ // Each 2x2 block takes 2 bytes to encode, one each for U and V.
+ final int uvSize = ((width + 1) / 2) * ((height + 1) / 2) * 2;
+
+ return ySize + uvSize;
+ }
+
+ /**
+ * Saves a Bitmap object to disk for analysis.
+ *
+ * @param bitmap The bitmap to save.
+ */
+ public static void saveBitmap(final Bitmap bitmap) {
+ final String root =
+ Environment.getExternalStorageDirectory().getAbsolutePath() + File.separator + "tensorflow";
+ LOGGER.i("Saving %dx%d bitmap to %s.", bitmap.getWidth(), bitmap.getHeight(), root);
+ final File myDir = new File(root);
+
+ // NOTE(review): mkdirs() also returns false when the directory already exists, so this log
+ // is misleading in the common case; check isDirectory() as well before warning.
+ if (!myDir.mkdirs()) {
+ LOGGER.i("Make dir failed");
+ }
+
+ final String fname = "preview.png";
+ final File file = new File(myDir, fname);
+ if (file.exists()) {
+ file.delete();
+ }
+ try {
+ // NOTE(review): the stream is leaked if compress()/flush() throws — prefer
+ // try-with-resources or a finally block. Also, the quality argument (99) is ignored for
+ // the lossless PNG format.
+ final FileOutputStream out = new FileOutputStream(file);
+ bitmap.compress(Bitmap.CompressFormat.PNG, 99, out);
+ out.flush();
+ out.close();
+ } catch (final Exception e) {
+ LOGGER.e(e, "Exception!");
+ }
+ }
+
+ /**
+ * Converts YUV420 semi-planar data to ARGB 8888 data using the supplied width
+ * and height. The input and output must already be allocated and non-null.
+ * For efficiency, no error checking is performed.
+ *
+ * @param input The array of YUV 4:2:0 input data.
+ * @param output A pre-allocated array for the RGB 5:6:5 output data.
+ * @param width The width of the input image.
+ * @param height The height of the input image.
+ */
+ public static native void convertYUV420SPToRGB565(
+ byte[] input, byte[] output, int width, int height);
+
+ /**
+ * Converts 32-bit ARGB8888 image data to YUV420SP data. This is useful, for
+ * instance, in creating data to feed the classes that rely on raw camera
+ * preview frames.
+ *
+ * @param input An array of input pixels in ARGB8888 format.
+ * @param output A pre-allocated array for the YUV420SP output data.
+ * @param width The width of the input image.
+ * @param height The height of the input image.
+ */
+ public static native void convertARGB8888ToYUV420SP(
+ int[] input, byte[] output, int width, int height);
+
+ /**
+ * Converts 16-bit RGB565 image data to YUV420SP data. This is useful, for
+ * instance, in creating data to feed the classes that rely on raw camera
+ * preview frames.
+ *
+ * @param input An array of input pixels in RGB565 format.
+ * @param output A pre-allocated array for the YUV420SP output data.
+ * @param width The width of the input image.
+ * @param height The height of the input image.
+ */
+ public static native void convertRGB565ToYUV420SP(
+ byte[] input, byte[] output, int width, int height);
+}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java
new file mode 100644
index 0000000000..697c231176
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/env/Logger.java
@@ -0,0 +1,176 @@
+package org.tensorflow.demo.env;
+
+import android.util.Log;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Wrapper for the platform log function, allows convenient message prefixing and log disabling.
+ */
+public final class Logger {
+ private static final String DEFAULT_TAG = "tensorflow";
+ private static final int DEFAULT_MIN_LOG_LEVEL = Log.DEBUG;
+
+ // Classes to be ignored when examining the stack trace
+ // (frames from these classes are skipped when inferring the caller's simple name).
+ private static final Set<String> IGNORED_CLASS_NAMES;
+
+ static {
+ IGNORED_CLASS_NAMES = new HashSet<String>(3);
+ IGNORED_CLASS_NAMES.add("dalvik.system.VMStack");
+ IGNORED_CLASS_NAMES.add("java.lang.Thread");
+ IGNORED_CLASS_NAMES.add(Logger.class.getCanonicalName());
+ }
+
+ private final String tag;
+ private final String messagePrefix;
+ private int minLogLevel = DEFAULT_MIN_LOG_LEVEL;
+
+ /**
+ * Creates a Logger using the class name as the message prefix.
+ *
+ * @param clazz the simple name of this class is used as the message prefix.
+ */
+ public Logger(final Class<?> clazz) {
+ this(clazz.getSimpleName());
+ }
+
+ /**
+ * Creates a Logger using the specified message prefix.
+ *
+ * @param messagePrefix is prepended to the text of every message.
+ */
+ public Logger(final String messagePrefix) {
+ this(DEFAULT_TAG, messagePrefix);
+ }
+
+ /**
+ * Creates a Logger with a custom tag and a custom message prefix. If the message prefix
+ * is set to <pre>null</pre>, the caller's class name is used as the prefix.
+ *
+ * @param tag identifies the source of a log message.
+ * @param messagePrefix prepended to every message if non-null. If null, the name of the caller is
+ * being used
+ */
+ public Logger(final String tag, final String messagePrefix) {
+ this.tag = tag;
+ final String prefix = messagePrefix == null ? getCallerSimpleName() : messagePrefix;
+ this.messagePrefix = (prefix.length() > 0) ? prefix + ": " : prefix;
+ }
+
+ /**
+ * Creates a Logger using the caller's class name as the message prefix.
+ */
+ public Logger() {
+ this(DEFAULT_TAG, null);
+ }
+
+ /**
+ * Creates a Logger using the caller's class name as the message prefix.
+ */
+ // NOTE(review): this overload also sets the minimum level that will be logged; the Javadoc
+ // above does not mention the minLogLevel parameter.
+ public Logger(final int minLogLevel) {
+ this(DEFAULT_TAG, null);
+ this.minLogLevel = minLogLevel;
+ }
+
+ public void setMinLogLevel(final int minLogLevel) {
+ this.minLogLevel = minLogLevel;
+ }
+
+ // A level is loggable if it meets this instance's threshold OR the platform enables it for tag.
+ public boolean isLoggable(final int logLevel) {
+ return logLevel >= minLogLevel || Log.isLoggable(tag, logLevel);
+ }
+
+ /**
+ * Return caller's simple name.
+ *
+ * Android getStackTrace() returns an array that looks like this:
+ * stackTrace[0]: dalvik.system.VMStack
+ * stackTrace[1]: java.lang.Thread
+ * stackTrace[2]: com.google.android.apps.unveil.env.UnveilLogger
+ * stackTrace[3]: com.google.android.apps.unveil.BaseApplication
+ *
+ * This function returns the simple version of the first non-filtered name.
+ *
+ * @return caller's simple name
+ */
+ private static String getCallerSimpleName() {
+ // Get the current callstack so we can pull the class of the caller off of it.
+ final StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
+
+ for (final StackTraceElement elem : stackTrace) {
+ final String className = elem.getClassName();
+ if (!IGNORED_CLASS_NAMES.contains(className)) {
+ // We're only interested in the simple name of the class, not the complete package.
+ final String[] classParts = className.split("\\.");
+ return classParts[classParts.length - 1];
+ }
+ }
+
+ return Logger.class.getSimpleName();
+ }
+
+ // With no varargs the format string is returned verbatim (not passed through String.format),
+ // so pre-formatted messages containing '%' are only safe in the zero-argument case.
+ private String toMessage(final String format, final Object... args) {
+ return messagePrefix + (args.length > 0 ? String.format(format, args) : format);
+ }
+
+ // Level-specific wrappers below: each checks isLoggable first, then delegates to android.util.Log
+ // with the prefixed, formatted message (and optional Throwable).
+ public void v(final String format, final Object... args) {
+ if (isLoggable(Log.VERBOSE)) {
+ Log.v(tag, toMessage(format, args));
+ }
+ }
+
+ public void v(final Throwable t, final String format, final Object... args) {
+ if (isLoggable(Log.VERBOSE)) {
+ Log.v(tag, toMessage(format, args), t);
+ }
+ }
+
+ public void d(final String format, final Object... 
args) {
+ if (isLoggable(Log.DEBUG)) {
+ Log.d(tag, toMessage(format, args));
+ }
+ }
+
+ public void d(final Throwable t, final String format, final Object... args) {
+ if (isLoggable(Log.DEBUG)) {
+ Log.d(tag, toMessage(format, args), t);
+ }
+ }
+
+ public void i(final String format, final Object... args) {
+ if (isLoggable(Log.INFO)) {
+ Log.i(tag, toMessage(format, args));
+ }
+ }
+
+ public void i(final Throwable t, final String format, final Object... args) {
+ if (isLoggable(Log.INFO)) {
+ Log.i(tag, toMessage(format, args), t);
+ }
+ }
+
+ public void w(final String format, final Object... args) {
+ if (isLoggable(Log.WARN)) {
+ Log.w(tag, toMessage(format, args));
+ }
+ }
+
+ public void w(final Throwable t, final String format, final Object... args) {
+ if (isLoggable(Log.WARN)) {
+ Log.w(tag, toMessage(format, args), t);
+ }
+ }
+
+ public void e(final String format, final Object... args) {
+ if (isLoggable(Log.ERROR)) {
+ Log.e(tag, toMessage(format, args));
+ }
+ }
+
+ public void e(final Throwable t, final String format, final Object... args) {
+ if (isLoggable(Log.ERROR)) {
+ Log.e(tag, toMessage(format, args), t);
+ }
+ }
+}
 |