aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2017-03-23 14:20:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-23 15:39:26 -0700
commit2de2d52164e2b3d0ddb69b7c1037452419de25ce (patch)
tree01db062c17fe85bcecbd560faa12ec89b58c204e
parentdf70d61ea2a76fc226237ef663e94245fb523376 (diff)
Android: update TensorFlowInferenceInterface: replace error code returns with exceptions, use longs for dimensions, and simplify method signatures where possible.
Note this is a breaking change and will cause the TF demo in tensorflow/examples/android to not work with older jars/.so binaries built from tensorflow/contrib/android
Change: 151061702
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java227
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java18
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java31
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java26
-rw-r--r--tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java29
5 files changed, 134 insertions, 197 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 713f63b806..1f180429b2 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
+import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
@@ -45,7 +46,13 @@ public class TensorFlowInferenceInterface {
private static final String TAG = "TensorFlowInferenceInterface";
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
- public TensorFlowInferenceInterface() {
+ /**
+ * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
+ *
+ * @param assetManager The AssetManager to use to load the model file.
+ * @param model The filepath to the GraphDef proto representing the model.
+ */
+ public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
try {
// Hack to see if the native libraries have been loaded.
@@ -63,16 +70,12 @@ public class TensorFlowInferenceInterface {
+ " libraries are present in the APK.");
}
}
- }
- /**
- * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
- *
- * @param assetManager The AssetManager to use to load the model file.
- * @param model The filepath to the GraphDef proto representing the model.
- * @return 0 on success.
- */
- public int initializeTensorFlow(AssetManager assetManager, String model) {
+ this.modelName = model;
+ this.g = new Graph();
+ this.sess = new Session(g);
+ this.runner = sess.runner();
+
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
InputStream is = null;
try {
@@ -80,37 +83,42 @@ public class TensorFlowInferenceInterface {
is = assetManager.open(aname);
} catch (IOException e) {
if (hasAssetPrefix) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
// Perhaps the model file is not an asset but is on disk.
try {
is = new FileInputStream(model);
} catch (IOException e2) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e2.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e2);
}
}
try {
- load(is);
+ loadGraph(is, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
- return 0;
} catch (IOException e) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
/**
- * Runs inference between the previously registered input nodes (via fillNode*) and the requested
- * output nodes. Output nodes can then be queried with the readNode* methods.
+ * Runs inference between the previously registered input nodes (via feed*) and the requested
+ * output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
- * @return 0 on success.
*/
- public int runInference(String[] outputNames) {
- // Release any Tensors from the previous runInference calls.
+ public void run(String[] outputNames) {
+ run(outputNames, false);
+ }
+
+ /**
+ * Runs inference between the previously registered input nodes (via feed*) and the requested
+ * output nodes. Output nodes can then be queried with the fetch* methods.
+ *
+ * @param outputNames A list of output nodes which should be filled by the inference pass.
+ * @param enableStats Whether to collect per-run execution statistics; adds overhead, so it
+ *     should only be enabled when needed.
+ */
+ public void run(String[] outputNames, boolean enableStats) {
+ // Release any Tensors from the previous run calls.
closeFetches();
// Add fetches.
@@ -125,6 +133,10 @@ public class TensorFlowInferenceInterface {
if (enableStats) {
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
fetchTensors = r.outputs;
+
+ if (runStats == null) {
+ runStats = new RunStats();
+ }
runStats.add(r.metadata);
} else {
fetchTensors = runner.run();
@@ -139,16 +151,13 @@ public class TensorFlowInferenceInterface {
+ "], outputs:["
+ TextUtils.join(", ", fetchNames)
+ "]");
- Log.e(TAG, "Inference exception: " + e.toString());
- return -1;
+ throw e;
} finally {
- // Always release the feeds (to save resources) and reset the runner, this runInference is
+ // Always release the feeds (to save resources) and reset the runner, this run is
// over.
closeFeeds();
runner = sess.runner();
}
-
- return 0;
}
/** Returns a reference to the Graph describing the computation run during inference. */
@@ -156,15 +165,13 @@ public class TensorFlowInferenceInterface {
return g;
}
- /**
- * Whether to collect stats during inference. This should only be enabled when needed, as it will
- * add overhead.
- */
- public void enableStatLogging(boolean enabled) {
- enableStats = enabled;
- if (enableStats && runStats == null) {
- runStats = new RunStats();
+ public Operation graphOperation(String operationName) {
+ final Operation operation = g.operation(operationName);
+ if (operation == null) {
+ throw new RuntimeException(
+ "Node '" + operationName + "' does not exist in model '" + modelName + "'");
}
+ return operation;
}
/** Returns the last stat summary string if logging is enabled. */
@@ -185,7 +192,15 @@ public class TensorFlowInferenceInterface {
runStats.close();
}
runStats = null;
- enableStats = false;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
}
// Methods for taking a native Tensor and filling it with values from Java arrays.
@@ -196,8 +211,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFloat(String inputName, int[] dims, float[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), FloatBuffer.wrap(src)));
+ public void feed(String inputName, float[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
}
/**
@@ -206,8 +221,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeInt(String inputName, int[] dims, int[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), IntBuffer.wrap(src)));
+ public void feed(String inputName, int[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
}
/**
@@ -216,8 +231,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeDouble(String inputName, int[] dims, double[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), DoubleBuffer.wrap(src)));
+ public void feed(String inputName, double[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
}
/**
@@ -226,8 +241,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeByte(String inputName, int[] dims, byte[] src) {
- addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), ByteBuffer.wrap(src)));
+ public void feed(String inputName, byte[] src, long... dims) {
+ addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
@@ -239,8 +254,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromFloatBuffer(String inputName, IntBuffer dims, FloatBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, FloatBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -250,8 +265,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromIntBuffer(String inputName, IntBuffer dims, IntBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, IntBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -261,8 +276,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromDoubleBuffer(String inputName, IntBuffer dims, DoubleBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, DoubleBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -272,52 +287,44 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromByteBuffer(String inputName, IntBuffer dims, ByteBuffer src) {
- addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), src));
+ public void feed(String inputName, ByteBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeFloat(String outputName, float[] dst) {
- return readNodeIntoFloatBuffer(outputName, FloatBuffer.wrap(dst));
+ public void fetch(String outputName, float[] dst) {
+ fetch(outputName, FloatBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeInt(String outputName, int[] dst) {
- return readNodeIntoIntBuffer(outputName, IntBuffer.wrap(dst));
+ public void fetch(String outputName, int[] dst) {
+ fetch(outputName, IntBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeDouble(String outputName, double[] dst) {
- return readNodeIntoDoubleBuffer(outputName, DoubleBuffer.wrap(dst));
+ public void fetch(String outputName, double[] dst) {
+ fetch(outputName, DoubleBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeByte(String outputName, byte[] dst) {
- return readNodeIntoByteBuffer(outputName, ByteBuffer.wrap(dst));
+ public void fetch(String outputName, byte[] dst) {
+ fetch(outputName, ByteBuffer.wrap(dst));
}
/**
@@ -325,16 +332,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoFloatBuffer(String outputName, FloatBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, FloatBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -342,16 +342,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoIntBuffer(String outputName, IntBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, IntBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -359,16 +352,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoDoubleBuffer(String outputName, DoubleBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, DoubleBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -376,22 +362,12 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoByteBuffer(String outputName, ByteBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, ByteBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
- private void load(InputStream is) throws IOException {
- this.g = new Graph();
- this.sess = new Session(g);
- this.runner = sess.runner();
+ private void loadGraph(InputStream is, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
Trace.beginSection("initializeTensorFlow");
@@ -425,26 +401,6 @@ public class TensorFlowInferenceInterface {
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
- // The TensorFlowInferenceInterface API used int[] for dims, but the underlying TensorFlow runtime
- // allows for 64-bit dimension sizes, so it needs to be converted to a long[]
- private long[] mkDims(int[] dims) {
- long[] ret = new long[dims.length];
- for (int i = 0; i < dims.length; ++i) {
- ret[i] = (long) dims[i];
- }
- return ret;
- }
-
- // Similar to mkDims(int[]), with the shape provided in an IntBuffer.
- private long[] mkDims(IntBuffer dims) {
- if (dims.hasArray()) {
- return mkDims(dims.array());
- }
- int[] copy = new int[dims.remaining()];
- dims.duplicate().get(copy);
- return mkDims(copy);
- }
-
private void addFeed(String inputName, Tensor t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
@@ -485,9 +441,10 @@ public class TensorFlowInferenceInterface {
if (n.equals(outputName)) {
return fetchTensors.get(i);
}
- i++;
+ ++i;
}
- return null;
+ throw new RuntimeException(
+ "Node '" + outputName + "' was not provided to run(), so it cannot be read");
}
private void closeFeeds() {
@@ -507,10 +464,11 @@ public class TensorFlowInferenceInterface {
}
// State immutable between initializeTensorFlow calls.
- private Graph g;
- private Session sess;
+ private final String modelName;
+ private final Graph g;
+ private final Session sess;
- // State reset on every call to runInference.
+ // State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
private List<Tensor> feedTensors = new ArrayList<Tensor>();
@@ -518,6 +476,5 @@ public class TensorFlowInferenceInterface {
private List<Tensor> fetchTensors = new ArrayList<Tensor>();
// Mutable state.
- private boolean enableStats;
private RunStats runStats;
}
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
index c1a893e9ee..d0df1a4483 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java
@@ -369,8 +369,7 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
- inferenceInterface = new TensorFlowInferenceInterface();
- inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE);
+ inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
previewWidth = size.getWidth();
previewHeight = size.getHeight();
@@ -585,12 +584,12 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL
}
// Copy the input data into TensorFlow.
- inferenceInterface.fillNodeFloat(
- INPUT_NODE, new int[] {1, bitmap.getWidth(), bitmap.getHeight(), 3}, floatValues);
- inferenceInterface.fillNodeFloat(STYLE_NODE, new int[] {NUM_STYLES}, styleVals);
+ inferenceInterface.feed(
+ INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3);
+ inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES);
- inferenceInterface.runInference(new String[] {OUTPUT_NODE});
- inferenceInterface.readNodeFloat(OUTPUT_NODE, floatValues);
+ inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug());
+ inferenceInterface.fetch(OUTPUT_NODE, floatValues);
for (int i = 0; i < intValues.length; ++i) {
intValues[i] =
@@ -603,11 +602,6 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL
bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
}
- @Override
- public void onSetDebug(final boolean debug) {
- inferenceInterface.enableStatLogging(debug);
- }
-
private void renderDebug(final Canvas canvas) {
// TODO(andrewharp): move result display to its own View instead of using debug overlay.
final Bitmap texture = textureCopyBitmap;
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java
index 9f80665ee6..f660178ebe 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java
@@ -56,6 +56,8 @@ public class TensorFlowImageClassifier implements Classifier {
private float[] outputs;
private String[] outputNames;
+ private boolean logStats = false;
+
private TensorFlowInferenceInterface inferenceInterface;
private TensorFlowImageClassifier() {}
@@ -102,16 +104,10 @@ public class TensorFlowImageClassifier implements Classifier {
throw new RuntimeException("Problem reading label file!" , e);
}
- c.inferenceInterface = new TensorFlowInferenceInterface();
- if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
- throw new RuntimeException("TF initialization failed");
- }
+ c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
+
// The shape of the output is [N, NUM_CLASSES], where N is the batch size.
- final Operation operation = c.inferenceInterface.graph().operation(outputName);
- if (operation == null) {
- throw new RuntimeException("Node '" + outputName + "' does not exist in model '"
- + modelFilename + "'");
- }
+ final Operation operation = c.inferenceInterface.graphOperation(outputName);
final int numClasses = (int) operation.output(0).shape().size(1);
Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses);
@@ -149,19 +145,18 @@ public class TensorFlowImageClassifier implements Classifier {
Trace.endSection();
// Copy the input data into TensorFlow.
- Trace.beginSection("fillNodeFloat");
- inferenceInterface.fillNodeFloat(
- inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
+ Trace.beginSection("feed");
+ inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
// Run the inference call.
- Trace.beginSection("runInference");
- inferenceInterface.runInference(outputNames);
+ Trace.beginSection("run");
+ inferenceInterface.run(outputNames, logStats);
Trace.endSection();
// Copy the output Tensor back into the output array.
- Trace.beginSection("readNodeFloat");
- inferenceInterface.readNodeFloat(outputName, outputs);
+ Trace.beginSection("fetch");
+ inferenceInterface.fetch(outputName, outputs);
Trace.endSection();
// Find the best classifications.
@@ -192,8 +187,8 @@ public class TensorFlowImageClassifier implements Classifier {
}
@Override
- public void enableStatLogging(boolean debug) {
- inferenceInterface.enableStatLogging(debug);
+ public void enableStatLogging(boolean logStats) {
+ this.logStats = logStats;
}
@Override
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
index 2ec9476ef9..f3e7114335 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java
@@ -62,6 +62,8 @@ public class TensorFlowMultiBoxDetector implements Classifier {
private String[] outputNames;
private int numLocations;
+ private boolean logStats = false;
+
private TensorFlowInferenceInterface inferenceInterface;
private float[] boxPriors;
@@ -89,10 +91,7 @@ public class TensorFlowMultiBoxDetector implements Classifier {
final String outputScoresName) {
final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector();
- d.inferenceInterface = new TensorFlowInferenceInterface();
- if (d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) {
- throw new RuntimeException("TF initialization failed");
- }
+ d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
final Graph g = d.inferenceInterface.graph();
@@ -222,22 +221,21 @@ public class TensorFlowMultiBoxDetector implements Classifier {
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
- Trace.beginSection("fillNodeFloat");
- inferenceInterface.fillNodeFloat(
- inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
+ Trace.beginSection("feed");
+ inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
// Run the inference call.
- Trace.beginSection("runInference");
- inferenceInterface.runInference(outputNames);
+ Trace.beginSection("run");
+ inferenceInterface.run(outputNames, logStats);
Trace.endSection();
// Copy the output Tensor back into the output array.
- Trace.beginSection("readNodeFloat");
+ Trace.beginSection("fetch");
final float[] outputScoresEncoding = new float[numLocations];
final float[] outputLocationsEncoding = new float[numLocations * 4];
- inferenceInterface.readNodeFloat(outputNames[0], outputLocationsEncoding);
- inferenceInterface.readNodeFloat(outputNames[1], outputScoresEncoding);
+ inferenceInterface.fetch(outputNames[0], outputLocationsEncoding);
+ inferenceInterface.fetch(outputNames[1], outputScoresEncoding);
Trace.endSection();
outputLocations = decodeLocationsEncoding(outputLocationsEncoding);
@@ -275,8 +273,8 @@ public class TensorFlowMultiBoxDetector implements Classifier {
}
@Override
- public void enableStatLogging(final boolean debug) {
- inferenceInterface.enableStatLogging(debug);
+ public void enableStatLogging(final boolean logStats) {
+ this.logStats = logStats;
}
@Override
diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
index 062e8cc8d0..174723071d 100644
--- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
+++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java
@@ -86,6 +86,8 @@ public class TensorFlowYoloDetector implements Classifier {
private int blockSize;
+ private boolean logStats = false;
+
private TensorFlowInferenceInterface inferenceInterface;
/** Initializes a native TensorFlow session for classifying images. */
@@ -106,13 +108,8 @@ public class TensorFlowYoloDetector implements Classifier {
d.floatValues = new float[inputSize * inputSize * 3];
d.blockSize = blockSize;
- d.inferenceInterface = new TensorFlowInferenceInterface();
+ d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);
- final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename);
- if (status != 0) {
- LOGGER.e("TF init status: " + status);
- throw new RuntimeException("TF init status (" + status + ") != 0");
- }
return d;
}
@@ -157,30 +154,26 @@ public class TensorFlowYoloDetector implements Classifier {
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
- Trace.beginSection("fillNodeFloat");
- inferenceInterface.fillNodeFloat(
- inputName, new int[] {1, inputSize, inputSize, 3}, floatValues);
+ Trace.beginSection("feed");
+ inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3);
Trace.endSection();
timer.endSplit("ready for inference");
// Run the inference call.
- Trace.beginSection("runInference");
- final int resultCode = inferenceInterface.runInference(outputNames);
- if (resultCode != 0) {
- throw new RuntimeException("Bad result code from inference: " + resultCode);
- }
+ Trace.beginSection("run");
+ inferenceInterface.run(outputNames, logStats);
Trace.endSection();
timer.endSplit("ran inference");
// Copy the output Tensor back into the output array.
- Trace.beginSection("readNodeFloat");
+ Trace.beginSection("fetch");
final int gridWidth = bitmap.getWidth() / blockSize;
final int gridHeight = bitmap.getHeight() / blockSize;
final float[] output =
new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK];
- inferenceInterface.readNodeFloat(outputNames[0], output);
+ inferenceInterface.fetch(outputNames[0], output);
Trace.endSection();
// Find the best detections.
@@ -256,8 +249,8 @@ public class TensorFlowYoloDetector implements Classifier {
}
@Override
- public void enableStatLogging(final boolean debug) {
- inferenceInterface.enableStatLogging(debug);
+ public void enableStatLogging(final boolean logStats) {
+ this.logStats = logStats;
}
@Override