aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2017-01-11 14:05:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 14:26:29 -0800
commit6426c4efd825e80579c478a61f66579c1c3bcef1 (patch)
tree01859f3fcecdc4de174a89aae268270fd0e28cb5
parent54e2be2f62dc34fbaef0fd052ad28d656ede6535 (diff)
Android: add image stylization example demo based on "A Learned Representation For Artistic Style"
Change: 144247143
-rw-r--r--WORKSPACE7
-rw-r--r--tensorflow/examples/android/AndroidManifest.xml11
-rw-r--r--tensorflow/examples/android/BUILD1
-rw-r--r--tensorflow/examples/android/README.md4
-rw-r--r--tensorflow/examples/android/bin/AndroidManifest.xml64
-rw-r--r--tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml51
-rw-r--r--tensorflow/examples/android/res/values/base-strings.xml1
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java662
8 files changed, 800 insertions, 1 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 9d3622878c..e0931512f4 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -38,6 +38,13 @@ new_http_archive(
sha256 = "b4c178fd6236dcf0a20d25d07c45eebe85281263978c6a6f1dfc49d75befc45f"
)
+new_http_archive(
+ name = "stylize",
+ build_file = "models.BUILD",
+ url = "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip",
+ sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa"
+)
+
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_http_archive(
diff --git a/tensorflow/examples/android/AndroidManifest.xml b/tensorflow/examples/android/AndroidManifest.xml
index e388734564..9f229d8b9d 100644
--- a/tensorflow/examples/android/AndroidManifest.xml
+++ b/tensorflow/examples/android/AndroidManifest.xml
@@ -41,7 +41,7 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
-
+
<activity android:name="org.tensorflow.demo.DetectorActivity"
android:screenOrientation="portrait"
android:label="@string/activity_name_detection">
@@ -50,6 +50,15 @@
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
+
+ <activity android:name="org.tensorflow.demo.StylizeActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_stylize">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
</application>
</manifest>
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 3ba3a494ab..0c1cea5fc3 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -66,6 +66,7 @@ android_binary(
"//tensorflow/examples/android/assets:asset_files",
"@inception5h//:model_files",
"@mobile_multibox//:model_files",
+ "@stylize//:model_files",
],
assets_dir = "",
custom_package = "org.tensorflow.demo",
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index 79f543fb74..fbbe9f276b 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -22,6 +22,10 @@ existing application.
Demonstrates a model based on [Scalable Object Detection
using Deep Neural Networks](https://arxiv.org/abs/1312.2249) to
localize and track people in the camera preview in real-time.
+3. [TF Stylize](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java):
+ Uses a model based on [A Learned Representation For Artistic Style]
+ (https://arxiv.org/abs/1610.07629) to restyle the camera preview image
+ to that of a number of different artists.
## Prebuilt APK:
diff --git a/tensorflow/examples/android/bin/AndroidManifest.xml b/tensorflow/examples/android/bin/AndroidManifest.xml
new file mode 100644
index 0000000000..d4792bc482
--- /dev/null
+++ b/tensorflow/examples/android/bin/AndroidManifest.xml
@@ -0,0 +1,64 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="org.tensorflow.demo_internal">
+
+ <uses-permission android:name="android.permission.CAMERA" />
+ <uses-feature android:name="android.hardware.camera" />
+ <uses-feature android:name="android.hardware.camera.autofocus" />
+ <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
+
+ <uses-sdk
+ android:minSdkVersion="21"
+ android:targetSdkVersion="23" />
+
+ <application android:allowBackup="true"
+ android:debuggable="true"
+ android:label="@string/app_name"
+ android:icon="@drawable/ic_launcher"
+ android:theme="@style/MaterialTheme">
+
+ <activity android:name="org.tensorflow.demo.ClassifierActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_classification">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+
+ <activity android:name="org.tensorflow.demo.DetectorActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_detection">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+
+ <activity android:name="org.tensorflow.demo.StylizeActivity"
+ android:screenOrientation="portrait"
+ android:label="@string/activity_name_stylize">
+ <intent-filter>
+ <action android:name="android.intent.action.MAIN" />
+ <category android:name="android.intent.category.LAUNCHER" />
+ </intent-filter>
+ </activity>
+ </application>
+
+</manifest>
diff --git a/tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml
new file mode 100644
index 0000000000..1cdb24cab0
--- /dev/null
+++ b/tensorflow/examples/android/res/layout/camera_connection_fragment_stylize.xml
@@ -0,0 +1,51 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ android:orientation="vertical"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent">
+ <org.tensorflow.demo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentTop="true" />
+
+ <RelativeLayout
+ android:id="@+id/black"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:background="#FF000000" />
+
+ <GridView
+ android:id="@+id/grid_layout"
+ android:numColumns="7"
+ android:stretchMode="columnWidth"
+ android:layout_alignParentBottom="true"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content" />
+
+ <org.tensorflow.demo.OverlayView
+ android:id="@+id/overlay"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_alignParentTop="true" />
+
+ <org.tensorflow.demo.OverlayView
+ android:id="@+id/debug_overlay"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_alignParentTop="true" />
+</RelativeLayout>
diff --git a/tensorflow/examples/android/res/values/base-strings.xml b/tensorflow/examples/android/res/values/base-strings.xml
index f6c57d5030..56edb55def 100644
--- a/tensorflow/examples/android/res/values/base-strings.xml
+++ b/tensorflow/examples/android/res/values/base-strings.xml
@@ -19,4 +19,5 @@
<string name="app_name">TensorFlow Demo</string>
<string name="activity_name_classification">TF Classify</string>
<string name="activity_name_detection">TF Detect</string>
+ <string name="activity_name_stylize">TF Stylize</string>
</resources>
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
new file mode 100644
index 0000000000..8a3c7a4ef9
--- /dev/null
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
@@ -0,0 +1,662 @@
+/*
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tensorflow.demo;
+
+import android.content.Context;
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+import android.graphics.Bitmap.Config;
+import android.graphics.BitmapFactory;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Matrix;
+import android.graphics.Paint;
+import android.graphics.Paint.Style;
+import android.graphics.Rect;
+import android.graphics.Typeface;
+import android.media.Image;
+import android.media.Image.Plane;
+import android.media.ImageReader;
+import android.media.ImageReader.OnImageAvailableListener;
+import android.os.Bundle;
+import android.os.SystemClock;
+import android.os.Trace;
+import android.util.Size;
+import android.util.TypedValue;
+import android.view.Display;
+import android.view.MotionEvent;
+import android.view.View;
+import android.view.View.OnClickListener;
+import android.view.View.OnTouchListener;
+import android.view.ViewGroup;
+import android.widget.BaseAdapter;
+import android.widget.Button;
+import android.widget.GridView;
+import android.widget.ImageView;
+import android.widget.Toast;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Vector;
+import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
+import org.tensorflow.demo.OverlayView.DrawCallback;
+import org.tensorflow.demo.env.BorderedText;
+import org.tensorflow.demo.env.ImageUtils;
+import org.tensorflow.demo.env.Logger;
+import org.tensorflow.demo.R;
+
+/**
+ * Sample activity that stylizes the camera preview according to "A Learned Representation For
+ * Artistic Style" (https://arxiv.org/abs/1610.07629)
+ */
+public class StylizeActivity extends CameraActivity implements OnImageAvailableListener {
+ static {
+ System.loadLibrary("tensorflow_demo");
+ }
+
+ private static final Logger LOGGER = new Logger();
+
+ private static final String MODEL_FILE = "file:///android_asset/stylize_quantized.pb";
+ private static final String INPUT_NODE = "input:0";
+ private static final String STYLE_NODE = "style_num:0";
+ private static final String OUTPUT_NODE = "transformer/expand/conv3/conv/Sigmoid";
+ private static final int NUM_STYLES = 26;
+
+ private static final boolean SAVE_PREVIEW_BITMAP = false;
+
+ // Whether to actively manipulate non-selected sliders so that sum of activations always appears
+ // to be 1.0. The actual style input tensor will be normalized to sum to 1.0 regardless.
+ private static final boolean NORMALIZE_SLIDERS = true;
+
+ private static final float TEXT_SIZE_DIP = 12;
+
+ private static final boolean DEBUG_MODEL = false;
+
+ private static final int[] SIZES = {32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024};
+
+ // Start at a medium size, but let the user step up through smaller sizes so they don't get
+ // immediately stuck processing a large image.
+ private int desiredSizeIndex = -1;
+ private int desiredSize = 256;
+ private int initializedSize = 0;
+
+ private Integer sensorOrientation;
+
+ private int previewWidth = 0;
+ private int previewHeight = 0;
+ private byte[][] yuvBytes;
+ private int[] rgbBytes = null;
+ private Bitmap rgbFrameBitmap = null;
+ private Bitmap croppedBitmap = null;
+
+ private final float[] styleVals = new float[NUM_STYLES];
+ private int[] intValues;
+ private float[] floatValues;
+
+ private int frameNum = 0;
+
+ private Bitmap cropCopyBitmap;
+ private Bitmap textureCopyBitmap;
+
+ private boolean computing = false;
+
+ private Matrix frameToCropTransform;
+ private Matrix cropToFrameTransform;
+
+ private BorderedText borderedText;
+
+ private long lastProcessingTimeMs;
+
+ private TensorFlowInferenceInterface inferenceInterface;
+
+ private int lastOtherStyle = 1;
+
+ private boolean allZero = false;
+
+ private ImageGridAdapter adapter;
+ private GridView grid;
+
+ private final OnTouchListener gridTouchAdapter =
+ new OnTouchListener() {
+ ImageSlider slider = null;
+
+ @Override
+ public boolean onTouch(final View v, final MotionEvent event) {
+ switch (event.getActionMasked()) {
+ case MotionEvent.ACTION_DOWN:
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ final ImageSlider child = adapter.items[i];
+ final Rect rect = new Rect();
+ child.getHitRect(rect);
+ if (rect.contains((int) event.getX(), (int) event.getY())) {
+ slider = child;
+ slider.setHilighted(true);
+ }
+ }
+ break;
+
+ case MotionEvent.ACTION_MOVE:
+ if (slider != null) {
+ final Rect rect = new Rect();
+ slider.getHitRect(rect);
+
+ final float newSliderVal =
+ (float)
+ Math.min(
+ 1.0,
+ Math.max(
+ 0.0, 1.0 - (event.getY() - slider.getTop()) / slider.getHeight()));
+
+ setStyle(slider, newSliderVal);
+ }
+ break;
+
+ case MotionEvent.ACTION_UP:
+ if (slider != null) {
+ slider.setHilighted(false);
+ slider = null;
+ }
+ break;
+ }
+ return true;
+ }
+ };
+
+ @Override
+ public void onCreate(final Bundle savedInstanceState) {
+ super.onCreate(savedInstanceState);
+ }
+
+ @Override
+ protected int getLayoutId() {
+ return R.layout.camera_connection_fragment_stylize;
+ }
+
+ @Override
+ protected int getDesiredPreviewFrameSize() {
+ return SIZES[SIZES.length - 1];
+ }
+
+ public static Bitmap getBitmapFromAsset(final Context context, final String filePath) {
+ final AssetManager assetManager = context.getAssets();
+
+ Bitmap bitmap = null;
+ try {
+ final InputStream inputStream = assetManager.open(filePath);
+ bitmap = BitmapFactory.decodeStream(inputStream);
+ } catch (final IOException e) {
+ LOGGER.e("Error opening bitmap!", e);
+ }
+
+ return bitmap;
+ }
+
+ private class ImageSlider extends ImageView {
+ private float value = 0.0f;
+ private boolean hilighted = false;
+
+ private final Paint boxPaint;
+ private final Paint linePaint;
+
+ public ImageSlider(final Context context) {
+ super(context);
+ value = 0.0f;
+
+ boxPaint = new Paint();
+ boxPaint.setColor(Color.BLACK);
+ boxPaint.setAlpha(128);
+
+ linePaint = new Paint();
+ linePaint.setColor(Color.WHITE);
+ linePaint.setStrokeWidth(10.0f);
+ linePaint.setStyle(Style.STROKE);
+ }
+
+ @Override
+ public void onDraw(final Canvas canvas) {
+ super.onDraw(canvas);
+ final float y = (1.0f - value) * canvas.getHeight();
+
+ // If all sliders are zero, don't bother shading anything.
+ if (!allZero) {
+ canvas.drawRect(0, 0, canvas.getWidth(), y, boxPaint);
+ }
+
+ if (value > 0.0f) {
+ canvas.drawLine(0, y, canvas.getWidth(), y, linePaint);
+ }
+
+ if (hilighted) {
+ canvas.drawRect(0, 0, getWidth(), getHeight(), linePaint);
+ }
+ }
+
+ @Override
+ protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
+ }
+
+ public void setValue(final float value) {
+ this.value = value;
+ postInvalidate();
+ }
+
+ public void setHilighted(final boolean highlighted) {
+ this.hilighted = highlighted;
+ this.postInvalidate();
+ }
+ }
+
+ private class ImageGridAdapter extends BaseAdapter {
+ final ImageSlider[] items = new ImageSlider[NUM_STYLES];
+ final ArrayList<Button> buttons = new ArrayList<Button>();
+
+ {
+ final Button sizeButton =
+ new Button(StylizeActivity.this) {
+ @Override
+ protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
+ }
+ };
+ sizeButton.setText("" + desiredSize);
+ sizeButton.setOnClickListener(
+ new OnClickListener() {
+ @Override
+ public void onClick(final View v) {
+ desiredSizeIndex = (desiredSizeIndex + 1) % SIZES.length;
+ desiredSize = SIZES[desiredSizeIndex];
+ sizeButton.setText("" + desiredSize);
+ sizeButton.postInvalidate();
+ }
+ });
+
+ final Button saveButton =
+ new Button(StylizeActivity.this) {
+ @Override
+ protected void onMeasure(final int widthMeasureSpec, final int heightMeasureSpec) {
+ super.onMeasure(widthMeasureSpec, heightMeasureSpec);
+ setMeasuredDimension(getMeasuredWidth(), getMeasuredWidth());
+ }
+ };
+ saveButton.setText("Save");
+ saveButton.setOnClickListener(
+ new OnClickListener() {
+ @Override
+ public void onClick(final View v) {
+ if (textureCopyBitmap != null) {
+ // TODO(andrewharp): Save as jpeg with guaranteed unique filename.
+ ImageUtils.saveBitmap(textureCopyBitmap, "stylized" + frameNum + ".png");
+ Toast.makeText(
+ StylizeActivity.this,
+ "Saved image to: /sdcard/tensorflow/" + "stylized" + frameNum + ".png",
+ Toast.LENGTH_LONG)
+ .show();
+ }
+ }
+ });
+
+ buttons.add(sizeButton);
+ buttons.add(saveButton);
+
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ LOGGER.v("Creating item %d", i);
+
+ if (items[i] == null) {
+ final ImageSlider slider = new ImageSlider(StylizeActivity.this);
+ final Bitmap bm =
+ getBitmapFromAsset(StylizeActivity.this, "thumbnails/style" + i + ".jpg");
+ slider.setImageBitmap(bm);
+
+ items[i] = slider;
+ }
+ }
+ }
+
+ @Override
+ public int getCount() {
+ return buttons.size() + NUM_STYLES;
+ }
+
+ @Override
+ public Object getItem(final int position) {
+ if (position < buttons.size()) {
+ return buttons.get(position);
+ } else {
+ return items[position - buttons.size()];
+ }
+ }
+
+ @Override
+ public long getItemId(final int position) {
+ return getItem(position).hashCode();
+ }
+
+ @Override
+ public View getView(final int position, final View convertView, final ViewGroup parent) {
+ if (convertView != null) {
+ return convertView;
+ }
+ return (View) getItem(position);
+ }
+ }
+
+ @Override
+ public void onPreviewSizeChosen(final Size size, final int rotation) {
+ final float textSizePx =
+ TypedValue.applyDimension(
+ TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
+ borderedText = new BorderedText(textSizePx);
+ borderedText.setTypeface(Typeface.MONOSPACE);
+
+ inferenceInterface = new TensorFlowInferenceInterface();
+ inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
+
+ previewWidth = size.getWidth();
+ previewHeight = size.getHeight();
+
+ final Display display = getWindowManager().getDefaultDisplay();
+ final int screenOrientation = display.getRotation();
+
+ LOGGER.i("Sensor orientation: %d, Screen orientation: %d", rotation, screenOrientation);
+
+ sensorOrientation = rotation + screenOrientation;
+
+ addCallback(
+ new DrawCallback() {
+ @Override
+ public void drawCallback(final Canvas canvas) {
+ renderDebug(canvas);
+ }
+ });
+
+ adapter = new ImageGridAdapter();
+ grid = (GridView) findViewById(R.id.grid_layout);
+ grid.setAdapter(adapter);
+ grid.setOnTouchListener(gridTouchAdapter);
+
+ setStyle(adapter.items[0], 1.0f);
+ }
+
+ private void setStyle(final ImageSlider slider, final float value) {
+ slider.setValue(value);
+
+ if (NORMALIZE_SLIDERS) {
+ // Slider vals correspond directly to the input tensor vals, and normalization is visually
+ // maintained by remanipulating non-selected sliders.
+ float otherSum = 0.0f;
+
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ if (adapter.items[i] != slider) {
+ otherSum += adapter.items[i].value;
+ }
+ }
+
+ if (otherSum > 0.0) {
+ float highestOtherVal = 0;
+ final float factor = otherSum > 0.0f ? (1.0f - value) / otherSum : 0.0f;
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ final ImageSlider child = adapter.items[i];
+ if (child == slider) {
+ continue;
+ }
+ final float newVal = child.value * factor;
+ child.setValue(newVal > 0.01f ? newVal : 0.0f);
+
+ if (child.value > highestOtherVal) {
+ lastOtherStyle = i;
+ highestOtherVal = child.value;
+ }
+ }
+ } else {
+ // Everything else is 0, so just pick a suitable slider to push up when the
+ // selected one goes down.
+ if (adapter.items[lastOtherStyle] == slider) {
+ lastOtherStyle = lastOtherStyle + 1 % NUM_STYLES;
+ }
+ adapter.items[lastOtherStyle].setValue(1.0f - value);
+ }
+ }
+
+ final boolean lastAllZero = allZero;
+ float sum = 0.0f;
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ sum += adapter.items[i].value;
+ }
+ allZero = sum == 0.0f;
+
+ // Now update the values used for the input tensor. If nothing is set, mix in everything
+ // equally. Otherwise everything is normalized to sum to 1.0.
+ for (int i = 0; i < NUM_STYLES; ++i) {
+ styleVals[i] = allZero ? 1.0f / NUM_STYLES : adapter.items[i].value / sum;
+
+ if (lastAllZero != allZero) {
+ adapter.items[i].postInvalidate();
+ }
+ }
+ }
+
+ @Override
+ public void onImageAvailable(final ImageReader reader) {
+ Image image = null;
+
+ try {
+ image = reader.acquireLatestImage();
+
+ if (image == null) {
+ return;
+ }
+
+ if (computing) {
+ image.close();
+ return;
+ }
+
+ if (desiredSize != initializedSize) {
+ LOGGER.i(
+ "Initializing at size preview size %dx%d, stylize size %d",
+ previewWidth, previewHeight, desiredSize);
+ rgbBytes = new int[previewWidth * previewHeight];
+ rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
+ croppedBitmap = Bitmap.createBitmap(desiredSize, desiredSize, Config.ARGB_8888);
+
+ frameToCropTransform =
+ ImageUtils.getTransformationMatrix(
+ previewWidth, previewHeight,
+ desiredSize, desiredSize,
+ sensorOrientation, true);
+
+ cropToFrameTransform = new Matrix();
+ frameToCropTransform.invert(cropToFrameTransform);
+
+ yuvBytes = new byte[3][];
+
+ intValues = new int[desiredSize * desiredSize];
+ floatValues = new float[desiredSize * desiredSize * 3];
+ initializedSize = desiredSize;
+ }
+
+ computing = true;
+
+ Trace.beginSection("imageAvailable");
+
+ final Plane[] planes = image.getPlanes();
+ fillBytes(planes, yuvBytes);
+
+ final int yRowStride = planes[0].getRowStride();
+ final int uvRowStride = planes[1].getRowStride();
+ final int uvPixelStride = planes[1].getPixelStride();
+ ImageUtils.convertYUV420ToARGB8888(
+ yuvBytes[0],
+ yuvBytes[1],
+ yuvBytes[2],
+ rgbBytes,
+ previewWidth,
+ previewHeight,
+ yRowStride,
+ uvRowStride,
+ uvPixelStride,
+ false);
+
+ image.close();
+ } catch (final Exception e) {
+ if (image != null) {
+ image.close();
+ }
+ LOGGER.e(e, "Exception!");
+ Trace.endSection();
+ return;
+ }
+
+ rgbFrameBitmap.setPixels(rgbBytes, 0, previewWidth, 0, 0, previewWidth, previewHeight);
+ final Canvas canvas = new Canvas(croppedBitmap);
+ canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
+
+ // For examining the actual TF input.
+ if (SAVE_PREVIEW_BITMAP) {
+ ImageUtils.saveBitmap(croppedBitmap);
+ }
+
+ runInBackground(
+ new Runnable() {
+ @Override
+ public void run() {
+ cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+
+ final long startTime = SystemClock.uptimeMillis();
+ stylizeImage(croppedBitmap);
+ lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;
+
+ textureCopyBitmap = Bitmap.createBitmap(croppedBitmap);
+
+ requestRender();
+ computing = false;
+ }
+ });
+
+ Trace.endSection();
+ }
+
+ String outputNode = "";
+
+ private void stylizeImage(final Bitmap bitmap) {
+ ++frameNum;
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+
+ if (DEBUG_MODEL) {
+ // Create a white square that steps through a black background 1 pixel per frame.
+ final int centerX = (frameNum + bitmap.getWidth() / 2) % bitmap.getWidth();
+ final int centerY = bitmap.getHeight() / 2;
+ final int squareSize = 10;
+ for (int i = 0; i < intValues.length; ++i) {
+ final int x = i % bitmap.getWidth();
+ final int y = i / bitmap.getHeight();
+ final float val =
+ Math.abs(x - centerX) < squareSize && Math.abs(y - centerY) < squareSize ? 1.0f : 0.0f;
+ floatValues[i * 3] = val;
+ floatValues[i * 3 + 1] = val;
+ floatValues[i * 3 + 2] = val;
+ }
+ } else {
+ for (int i = 0; i < intValues.length; ++i) {
+ final int val = intValues[i];
+ floatValues[i * 3] = ((val >> 16) & 0xFF) / 255.0f;
+ floatValues[i * 3 + 1] = ((val >> 8) & 0xFF) / 255.0f;
+ floatValues[i * 3 + 2] = (val & 0xFF) / 255.0f;
+ }
+ }
+
+ // Copy the input data into TensorFlow.
+ inferenceInterface.fillNodeFloat(
+ INPUT_NODE, new int[] {1, bitmap.getWidth(), bitmap.getHeight(), 3}, floatValues);
+ inferenceInterface.fillNodeFloat(STYLE_NODE, new int[] {NUM_STYLES}, styleVals);
+
+ inferenceInterface.runInference(new String[] {OUTPUT_NODE});
+ inferenceInterface.readNodeFloat(OUTPUT_NODE, floatValues);
+
+ for (int i = 0; i < intValues.length; ++i) {
+ intValues[i] =
+ 0xFF000000
+ | (((int) (floatValues[i * 3] * 255)) << 16)
+ | (((int) (floatValues[i * 3 + 1] * 255)) << 8)
+ | ((int) (floatValues[i * 3 + 2] * 255));
+ }
+
+ bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ }
+
+ @Override
+ public void onSetDebug(final boolean debug) {
+ inferenceInterface.enableStatLogging(debug);
+ }
+
+ private void renderDebug(final Canvas canvas) {
+ // TODO(andrewharp): move result display to its own View instead of using debug overlay.
+ final Bitmap texture = textureCopyBitmap;
+ if (texture != null) {
+ final Matrix matrix = new Matrix();
+ final float scaleFactor =
+ DEBUG_MODEL
+ ? 4.0f
+ : Math.min(
+ (float) canvas.getWidth() / texture.getWidth(),
+ (float) canvas.getHeight() / texture.getHeight());
+ matrix.postScale(scaleFactor, scaleFactor);
+ canvas.drawBitmap(texture, matrix, new Paint());
+ }
+
+ if (!isDebug()) {
+ return;
+ }
+
+ final Bitmap copy = cropCopyBitmap;
+ if (copy == null) {
+ return;
+ }
+
+ canvas.drawColor(0x55000000);
+
+ final Matrix matrix = new Matrix();
+ final float scaleFactor = 2;
+ matrix.postScale(scaleFactor, scaleFactor);
+ matrix.postTranslate(
+ canvas.getWidth() - copy.getWidth() * scaleFactor,
+ canvas.getHeight() - copy.getHeight() * scaleFactor);
+ canvas.drawBitmap(copy, matrix, new Paint());
+
+ final Vector<String> lines = new Vector<String>();
+
+ final String[] statLines = inferenceInterface.getStatString().split("\n");
+ for (final String line : statLines) {
+ lines.add(line);
+ }
+
+ lines.add("");
+
+ lines.add("Frame: " + previewWidth + "x" + previewHeight);
+ lines.add("Crop: " + copy.getWidth() + "x" + copy.getHeight());
+ lines.add("View: " + canvas.getWidth() + "x" + canvas.getHeight());
+ lines.add("Rotation: " + sensorOrientation);
+ lines.add("Inference time: " + lastProcessingTimeMs + "ms");
+ lines.add("Desired size: " + desiredSize);
+ lines.add("Initialized size: " + initializedSize);
+
+ borderedText.drawLines(canvas, 10, canvas.getHeight() - 10, lines);
+ }
+}