path: root/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
diff options
authorGravatar Andrew Harp <andrewharp@google.com>2017-01-11 14:05:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 14:26:29 -0800
commit6426c4efd825e80579c478a61f66579c1c3bcef1 (patch)
tree01859f3fcecdc4de174a89aae268270fd0e28cb5 /tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
parent54e2be2f62dc34fbaef0fd052ad28d656ede6535 (diff)
Android: add image stylization example demo based on "A Learned Representation For Artistic Style"
Change: 144247143
Diffstat (limited to 'tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java')
1 files changed, 662 insertions, 0 deletions
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
new file mode 100644
index 0000000000..8a3c7a4ef9
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
@@ -0,0 +1,662 @@
+ /*
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.demo;
+import android.content.Context;
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import android.graphics.BitmapFactory;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Style;
+import android.graphics.Rect;
+import android.graphics.Typeface;
+import android.media.Image;
+import android.media.Image.Plane;
+import android.media.ImageReader;
+import android.media.ImageReader.OnImageAvailableListener;
+import android.os.Bundle;
+import android.os.SystemClock;
+import android.os.Trace;
+import android.util.Size;
+import android.util.TypedValue;
+import android.view.Display;
+import android.view.MotionEvent;
+import android.view.View;
+import android.view.View.OnClickListener;
+import android.view.View.OnTouchListener;
+import android.view.ViewGroup;
+import android.widget.BaseAdapter;
+import android.widget.Button;
+import android.widget.GridView;
+import android.widget.ImageView;
+import android.widget.Toast;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Vector;
+import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
+import org.tensorflow.demo.OverlayView.DrawCallback;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.R;
+ /**
+ * Sample activity that stylizes the camera preview according to "A Learned Representation For
+ * Artistic Style" (https://arxiv.org/abs/1610.07629)
+ */
+public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
+ static {
+ System.loadLibrary("tensorflow_demo"); // Loads the native demo library when this class is initialized.
+ }
+ private static final Logger LOGGER = new Logger();
+ private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb"; // Passed to initializeTensorFlow().
+ private static final String INPUT_NODE = "input:0"; // Fed with the RGB floats of the cropped frame.
+ private static final String STYLE_NODE = "style_num:0"; // Fed with styleVals.
+ private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid"; // Read back after inference.
+ private static final int NUM_STYLES = 26; // Number of style sliders and length of styleVals.
+ private static final boolean SAVE_PREVIEW_BITMAP = false;
+ // Whether to actively manipulate non-selected sliders so that sum of activations always appears
+ // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
+ private static final boolean NORMALIZE_SLIDERS = true;
+ private static final float TEXT_SIZE_DIP = 12; // Debug text size in DIP, converted to pixels in onPreviewSizeChosen().
+ private static final boolean DEBUG_MODEL = false;
+ private static final int[] SIZES = {32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024}; // Sizes the size button cycles through.
+ // Start at a medium size, but let the user step up through smaller sizes so they don't get
+ // immediately stuck processing a large image.
+ private int desiredSizeIndex = -1; // Index into SIZES; -1 so the first button press selects SIZES[0].
+ private int desiredSize = 256; // Side length requested for the stylized image.
+ private int initializedSize = 0; // Size the buffers were last allocated for; 0 forces allocation on the first frame.
+ private Integer sensorOrientation; // Sum of the rotation argument and the display rotation, set in onPreviewSizeChosen().
+ private int previewWidth = 0;
+ private int previewHeight = 0;
+ private byte[][] yuvBytes; // Per-plane buffers passed to fillBytes().
+ private int[] rgbBytes = null; // Output buffer of convertYUV420ToARGB8888(), one int per preview pixel.
+ private Bitmap rgbFrameBitmap = null; // Full preview frame.
+ private Bitmap croppedBitmap = null; // desiredSize x desiredSize bitmap that is stylized in place.
+ private final float[] styleVals = new float[NUM_STYLES]; // Style weights fed to STYLE_NODE; recomputed in setStyle().
+ private int[] intValues; // Packed pixels of the cropped bitmap.
+ private float[] floatValues; // Three floats per pixel; model input, then reused for the model output.
+ private int frameNum = 0; // Incremented once per stylizeImage() call; also used in the saved file name.
+ private Bitmap cropCopyBitmap; // Copy of the crop taken before stylization (drawn in the debug overlay).
+ private Bitmap textureCopyBitmap; // Copy of the crop taken after stylization (drawn and saved).
+ private boolean computing = false; // True while a frame is being processed; later frames are dropped meanwhile.
+ private Matrix frameToCropTransform;
+ private Matrix cropToFrameTransform; // Inverse of frameToCropTransform.
+ private BorderedText borderedText;
+ private long lastProcessingTimeMs; // Duration of the most recent stylizeImage() call.
+ private TensorFlowInferenceInterface inferenceInterface;
+ private int lastOtherStyle = 1; // Index of the slider that is raised when the selected one is lowered and all others are zero.
+ private boolean allZero = false; // True when all slider values sum to exactly zero.
+ private ImageGridAdapter adapter;
+ private GridView grid;
+ /**
+  * Touch handler installed on the style grid. A press selects and highlights the slider whose hit
+  * rect contains the touch point, vertical drags set that slider's value through setStyle(), and
+  * releasing or cancelling the gesture clears the highlight. The event is always consumed.
+  */
+ private final OnTouchListener gridTouchAdapter =
+     new OnTouchListener() {
+       // Slider captured by the current gesture; null while nothing is being dragged.
+       ImageSlider slider = null;
+       @Override
+       public boolean onTouch(final View v, final MotionEvent event) {
+         switch (event.getActionMasked()) {
+           case MotionEvent.ACTION_DOWN:
+             final Rect hitRect = new Rect();
+             for (int i = 0; i < NUM_STYLES; ++i) {
+               final ImageSlider child = adapter.items[i];
+               child.getHitRect(hitRect);
+               if (hitRect.contains((int) event.getX(), (int) event.getY())) {
+                 slider = child;
+                 slider.setHilighted(true);
+               }
+             }
+             break;
+           case MotionEvent.ACTION_MOVE:
+             if (slider != null) {
+               // Map the touch height inside the slider to [0, 1]: top edge is 1.0, bottom is 0.0.
+               final float newSliderVal =
+                   (float)
+                       Math.min(
+                           1.0,
+                           Math.max(
+                               0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
+               setStyle(slider, newSliderVal);
+             }
+             break;
+           case MotionEvent.ACTION_UP:
+           case MotionEvent.ACTION_CANCEL:
+             // A cancelled gesture never delivers ACTION_UP, so release the slider here as well;
+             // otherwise it would stay highlighted and captured.
+             if (slider != null) {
+               slider.setHilighted(false);
+               slider = null;
+             }
+             break;
+           default:
+             break;
+         }
+         return true;
+       }
+     };
+ @Override
+ public void onCreate(final Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState); // No activity-specific setup here; only delegates to the superclass.
+ }
+ @Override
+ protected int getLayoutId() {
+ return R.layout.camera_connection_fragment_stylize; // Layout resource for this activity; presumably consumed by CameraActivity (confirm).
+ }
+ @Override
+ protected int getDesiredPreviewFrameSize() {
+ return SIZES[SIZES.length - 1];
+ }
+ public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
+ final AssetManager assetManager = context.getAssets();
+ Bitmap bitmap = null;
+ try {
+ final InputStream inputStream = assetManager.open(filePath);
+ bitmap = BitmapFactory.decodeStream(inputStream);
+ } catch (final IOException e) {
+ LOGGER.e("Error opening bitmap!", e);
+ }
+ return bitmap;
+ }
+ /**
+  * Square image view that doubles as a vertical slider. The area above the current level is
+  * shaded, a line marks the level itself, and a border is drawn while the slider is highlighted.
+  */
+ private class ImageSlider extends ImageView {
+   private float value = 0.0f;
+   private boolean hilighted = false;
+   private final Paint boxPaint;
+   private final Paint linePaint;
+   public ImageSlider(final Context context) {
+     super(context);
+     // Half-transparent black used to shade the part of the image above the level.
+     final Paint shade = new Paint();
+     shade.setColor(Color.BLACK);
+     shade.setAlpha(128);
+     boxPaint = shade;
+     // White stroke used for both the level line and the highlight border.
+     final Paint stroke = new Paint();
+     stroke.setColor(Color.WHITE);
+     stroke.setStrokeWidth(10.0f);
+     stroke.setStyle(Style.STROKE);
+     linePaint = stroke;
+   }
+   @Override
+   public void onDraw(final Canvas canvas) {
+     super.onDraw(canvas);
+     final float levelY = (1.0f - value) * canvas.getHeight();
+     final boolean shadeAboveLevel = !allZero;
+     // If all sliders are zero, don't bother shading anything.
+     if (shadeAboveLevel) {
+       canvas.drawRect(0, 0, canvas.getWidth(), levelY, boxPaint);
+     }
+     if (value > 0.0f) {
+       canvas.drawLine(0, levelY, canvas.getWidth(), levelY, linePaint);
+     }
+     if (hilighted) {
+       canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
+     }
+   }
+   @Override
+   protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
+     super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+     // Force a square: the measured width is used for both dimensions.
+     final int side = getMeasuredWidth();
+     setMeasuredDimension(side, side);
+   }
+   /** Stores the new slider level and schedules a redraw. */
+   public void setValue(final float value) {
+     this.value = value;
+     postInvalidate();
+   }
+   /** Turns the highlight border on or off and schedules a redraw. */
+   public void setHilighted(final boolean highlighted) {
+     hilighted = highlighted;
+     postInvalidate();
+   }
+ }
+ /**
+  * Adapter backing the style grid. The first entries are the control buttons (size and save);
+  * they are followed by one ImageSlider per style, each showing that style's thumbnail asset.
+  */
+ private class ImageGridAdapter extends BaseAdapter {
+   final ImageSlider[] items = new ImageSlider[NUM_STYLES];
+   final ArrayList<Button> buttons = new ArrayList<Button>();
+   {
+     final Button sizeButton = newSquareButton();
+     sizeButton.setText("" + desiredSize);
+     sizeButton.setOnClickListener(
+         new OnClickListener() {
+           @Override
+           public void onClick(final View v) {
+             // Step to the next entry of SIZES, wrapping around at the end.
+             desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
+             desiredSize = SIZES[desiredSizeIndex];
+             sizeButton.setText("" + desiredSize);
+             sizeButton.postInvalidate();
+           }
+         });
+     final Button saveButton = newSquareButton();
+     saveButton.setText("Save");
+     saveButton.setOnClickListener(
+         new OnClickListener() {
+           @Override
+           public void onClick(final View v) {
+             if (textureCopyBitmap != null) {
+               // TODO(andrewharp): Save as jpeg with guaranteed unique filename.
+               // Build the name once so the toast always names the file that was written, even
+               // if frameNum changes in between.
+               final String filename = "stylized" + frameNum + ".png";
+               ImageUtils.saveBitmap(textureCopyBitmap, filename);
+               Toast.makeText(
+                       StylizeActivity.this,
+                       "Saved image to: /sdcard/tensorflow/" + filename,
+                       Toast.LENGTH_LONG)
+                   .show();
+             }
+           }
+         });
+     buttons.add(sizeButton);
+     buttons.add(saveButton);
+     for (int i = 0; i < NUM_STYLES; ++i) {
+       LOGGER.v("Creating item %d", i);
+       final ImageSlider slider = new ImageSlider(StylizeActivity.this);
+       final Bitmap bm =
+           getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
+       slider.setImageBitmap(bm);
+       items[i] = slider;
+     }
+   }
+   /** Creates a button whose measured height is forced to equal its measured width. */
+   private Button newSquareButton() {
+     return new Button(StylizeActivity.this) {
+       @Override
+       protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
+         super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+         setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
+       }
+     };
+   }
+   @Override
+   public int getCount() {
+     return buttons.size() + NUM_STYLES;
+   }
+   /** Returns the button at this position, or the slider once positions pass the buttons. */
+   @Override
+   public Object getItem(final int position) {
+     if (position < buttons.size()) {
+       return buttons.get(position);
+     } else {
+       return items[position - buttons.size()];
+     }
+   }
+   @Override
+   public long getItemId(final int position) {
+     return getItem(position).hashCode();
+   }
+   @Override
+   public View getView(final int position, final View convertView, final ViewGroup parent) {
+     if (convertView != null) {
+       return convertView;
+     }
+     return (View) getItem(position);
+   }
+ }
+ /**
+  * Sets up everything that depends on the chosen preview size: debug text, the inference
+  * interface, the recorded preview dimensions and orientation, the debug draw callback, and the
+  * style grid. Style 0 starts fully selected.
+  */
+ @Override
+ public void onPreviewSizeChosen(final Size size, final int rotation) {
+   borderedText =
+       new BorderedText(
+           TypedValue.applyDimension(
+               TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics()));
+   borderedText.setTypeface(Typeface.MONOSPACE);
+   inferenceInterface = new TensorFlowInferenceInterface();
+   inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
+   previewWidth = size.getWidth();
+   previewHeight = size.getHeight();
+   final int screenOrientation = getWindowManager().getDefaultDisplay().getRotation();
+   LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
+   sensorOrientation = rotation + screenOrientation;
+   final DrawCallback debugOverlay =
+       new DrawCallback() {
+         @Override
+         public void drawCallback(final Canvas canvas) {
+           renderDebug(canvas);
+         }
+       };
+   addCallback(debugOverlay);
+   adapter = new ImageGridAdapter();
+   grid = (GridView) findViewById(R.id.grid_layout);
+   grid.setAdapter(adapter);
+   grid.setOnTouchListener(gridTouchAdapter);
+   final ImageSlider firstStyle = adapter.items[0];
+   setStyle(firstStyle, 1.0f);
+ }
+ /**
+  * Sets one slider to a new value, optionally rebalances the other sliders so that all values
+  * appear to sum to 1.0, and recomputes the style weights in styleVals.
+  *
+  * @param slider the slider being changed
+  * @param value the new value for that slider
+  */
+ private void setStyle(final ImageSlider slider, final float value) {
+   slider.setValue(value);
+   if (NORMALIZE_SLIDERS) {
+     // Slider vals correspond directly to the input tensor vals, and normalization is visually
+     // maintained by remanipulating non-selected sliders.
+     float otherSum = 0.0f;
+     for (int i = 0; i < NUM_STYLES; ++i) {
+       if (adapter.items[i] != slider) {
+         otherSum += adapter.items[i].value;
+       }
+     }
+     if (otherSum > 0.0) {
+       // Scale the other sliders so that together they take up the remaining (1 - value).
+       float highestOtherVal = 0;
+       final float factor = (1.0f - value) / otherSum;
+       for (int i = 0; i < NUM_STYLES; ++i) {
+         final ImageSlider child = adapter.items[i];
+         if (child == slider) {
+           continue;
+         }
+         final float newVal = child.value * factor;
+         // Snap very small values to zero.
+         child.setValue(newVal > 0.01f ? newVal : 0.0f);
+         if (child.value > highestOtherVal) {
+           lastOtherStyle = i;
+           highestOtherVal = child.value;
+         }
+       }
+     } else {
+       // Everything else is 0, so just pick a suitable slider to push up when the
+       // selected one goes down.
+       if (adapter.items[lastOtherStyle] == slider) {
+         // Parenthesized so the index wraps around; without the parentheses "1 % NUM_STYLES"
+         // binds first and the index can run past the end of the items array.
+         lastOtherStyle = (lastOtherStyle + 1) % NUM_STYLES;
+       }
+       adapter.items[lastOtherStyle].setValue(1.0f - value);
+     }
+   }
+   final boolean lastAllZero = allZero;
+   float sum = 0.0f;
+   for (int i = 0; i < NUM_STYLES; ++i) {
+     sum += adapter.items[i].value;
+   }
+   allZero = sum == 0.0f;
+   // Now update the values used for the input tensor. If nothing is set, mix in everything
+   // equally. Otherwise everything is normalized to sum to 1.0.
+   for (int i = 0; i < NUM_STYLES; ++i) {
+     styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
+     // Sliders draw differently when everything is zero, so redraw all of them on a transition.
+     if (lastAllZero != allZero) {
+       adapter.items[i].postInvalidate();
+     }
+   }
+ }
+ @Override
+ public void onImageAvailable(final ImageReader reader) {
+ Image image = null;
+ try {
+ image = reader.acquireLatestImage();
+ if (image == null) {
+ return;
+ }
+ if (computing) {
+ image.close();
+ return;
+ }
+ if (desiredSize != initializedSize) {
+ "Initializing at size preview size %dx%d, stylize size %d",
+ previewWidth, previewHeight, desiredSize);
+ rgbBytes = new int[previewWidth * previewHeight];
+ rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
+ frameToCropTransform =
+ ImageUtils.getTransformationMatrix(
+ previewWidth, previewHeight,
+ desiredSize, desiredSize,
+ sensorOrientation, true);
+ cropToFrameTransform = new Matrix();
+ frameToCropTransform.invert(cropToFrameTransform);
+ yuvBytes = new byte[3][];
+ intValues = new int[desiredSize * desiredSize];
+ floatValues = new float[desiredSize * desiredSize * 3];
+ initializedSize = desiredSize;
+ }
+ computing = true;
+ Trace.beginSection("imageAvailable");
+ final Plane[] planes = image.getPlanes();
+ fillBytes(planes, yuvBytes);
+ final int yRowStride = planes[0].getRowStride();
+ final int uvRowStride = planes[1].getRowStride();
+ final int uvPixelStride = planes[1].getPixelStride();
+ ImageUtils.convertYUV420ToARGB8888(
+ yuvBytes[0],
+ yuvBytes[1],
+ yuvBytes[2],
+ rgbBytes,
+ previewWidth,
+ previewHeight,
+ yRowStride,
+ uvRowStride,
+ uvPixelStride,
+ false);
+ image.close();
+ } catch (final Exception e) {
+ if (image != null) {
+ image.close();
+ }
+ LOGGER.e(e, "Exception!");
+ Trace.endSection();
+ return;
+ }
+ rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ final Canvas canvas = new Canvas(croppedBitmap);
+ canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
+ // For examining the actual TF input.
+ ImageUtils.saveBitmap(croppedBitmap);
+ }
+ runInBackground(
+ new Runnable() {
+ @Override
+ public void run() {
+ cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+ final long startTime = SystemClock.uptimeMillis();
+ stylizeImage(croppedBitmap);
+ lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
+ textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+ requestRender();
+ computing = false;
+ }
+ });
+ Trace.endSection();
+ }
+ String outputNode = ""; // NOTE(review): not referenced anywhere in the visible source; confirm before removing.
+ /**
+  * Runs the stylization model on the bitmap and writes the result back into the same bitmap.
+  * Uses intValues and floatValues as scratch buffers and increments frameNum.
+  *
+  * @param bitmap bitmap to stylize in place; intValues must hold one entry per pixel
+  */
+ private void stylizeImage(final Bitmap bitmap) {
+   ++frameNum;
+   bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+   if (DEBUG_MODEL) {
+     // Create a white square that steps through a black background 1 pixel per frame.
+     final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
+     final int centerY = bitmap.getHeight() / 2;
+     final int squareSize = 10;
+     for (int i = 0; i < intValues.length; ++i) {
+       final int x = i % bitmap.getWidth();
+       // Pixels are laid out row by row with a stride of the bitmap width, so the row index is
+       // i / width (dividing by the height is only correct for square bitmaps).
+       final int y = i / bitmap.getWidth();
+       final float val =
+           Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
+       floatValues[i * 3] = val;
+       floatValues[i * 3 + 1] = val;
+       floatValues[i * 3 + 2] = val;
+     }
+   } else {
+     // Unpack each pixel into three floats in [0, 1], taken from bits 16-23, 8-15 and 0-7.
+     for (int i = 0; i < intValues.length; ++i) {
+       final int val = intValues[i];
+       floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
+       floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
+       floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
+     }
+   }
+   // Copy the input data into TensorFlow.
+   // NOTE(review): dims are passed as {1, width, height, 3}. The only caller passes a square
+   // bitmap, so the order does not matter here; confirm the layout before using other shapes.
+   inferenceInterface.fillNodeFloat(
+       INPUT_NODE, new int[] {1, bitmap.getWidth(), bitmap.getHeight(), 3}, floatValues);
+   inferenceInterface.fillNodeFloat(STYLE_NODE, new int[] {NUM_STYLES}, styleVals);
+   inferenceInterface.runInference(new String[] {OUTPUT_NODE});
+   inferenceInterface.readNodeFloat(OUTPUT_NODE, floatValues);
+   // Repack the output floats into opaque pixels.
+   for (int i = 0; i < intValues.length; ++i) {
+     intValues[i] =
+         0xFF000000
+             | (((int) (floatValues[i * 3] * 255)) << 16)
+             | (((int) (floatValues[i * 3 + 1] * 255)) << 8)
+             | ((int) (floatValues[i * 3 + 2] * 255));
+   }
+   bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ }
+ @Override
+ public void onSetDebug(final boolean debug) {
+ inferenceInterface.enableStatLogging(debug); // Forwards the debug flag to the inference interface's stat logging.
+ }
+ /**
+  * Draws the latest stylized image onto the canvas and, when debug mode is on, overlays the
+  * unstylized crop plus inference statistics and size information.
+  *
+  * @param canvas canvas to draw on
+  */
+ private void renderDebug(final Canvas canvas) {
+   // TODO(andrewharp): move result display to its own View instead of using debug overlay.
+   final Bitmap texture = textureCopyBitmap;
+   if (texture != null) {
+     final Matrix matrix = new Matrix();
+     // Fixed 4x scale when debugging the model; otherwise the largest uniform scale that fits
+     // the texture inside the canvas.
+     final float scaleFactor =
+         DEBUG_MODEL
+             ? 4.0f
+             : Math.min(
+                 (float) canvas.getWidth() / texture.getWidth(),
+                 (float) canvas.getHeight() / texture.getHeight());
+     matrix.postScale(scaleFactor, scaleFactor);
+     canvas.drawBitmap(texture, matrix, new Paint());
+   }
+   if (!isDebug()) {
+     return;
+   }
+   final Bitmap copy = cropCopyBitmap;
+   if (copy == null) {
+     return;
+   }
+   // Darken the canvas, then draw the crop at 2x in the bottom-right corner.
+   canvas.drawColor(0x55000000);
+   final Matrix matrix = new Matrix();
+   final float scaleFactor = 2;
+   matrix.postScale(scaleFactor, scaleFactor);
+   matrix.postTranslate(
+       canvas.getWidth() - copy.getWidth() * scaleFactor,
+       canvas.getHeight() - copy.getHeight() * scaleFactor);
+   canvas.drawBitmap(copy, matrix, new Paint());
+   final Vector<String> lines = new Vector<String>();
+   final String[] statLines = inferenceInterface.getStatString().split("\n");
+   for (final String line : statLines) {
+     lines.add(line);
+   }
+   lines.add("");
+   lines.add("Frame: " + previewWidth + "x" + previewHeight);
+   lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
+   lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
+   lines.add("Rotation: " + sensorOrientation);
+   lines.add("Inference time: " + lastProcessingTimeMs + "ms");
+   lines.add("Desired size: " + desiredSize);
+   lines.add("Initialized size: " + initializedSize);
+   borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
+ }