author    Andrew Harp <andrewharp@google.com>  2017-03-23 14:20:36 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-03-23 15:39:26 -0700
commit    2de2d52164e2b3d0ddb69b7c1037452419de25ce (patch)
tree      01db062c17fe85bcecbd560faa12ec89b58c204e /tensorflow/contrib/android
parent    df70d61ea2a76fc226237ef663e94245fb523376 (diff)
Android: update TensorFlowInferenceInterface: replace error code returns with exceptions, use longs for dimensions, and simplify method signatures where possible.
Note this is a breaking change and will cause the TF demo in tensorflow/examples/android to not work with older jars/.so binaries built from tensorflow/contrib/android.

Change: 151061702
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r--  tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 227
1 file changed, 92 insertions(+), 135 deletions(-)
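For context, a minimal before/after sketch of the migration this commit forces on callers. The node names ("input", "output") and the shape are placeholders, not part of the commit; only the method signatures come from the diff below.

```java
import android.content.res.AssetManager;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

class MigrationSketch {
  // Hypothetical node names and shape; only the signatures come from this commit.
  void classify(AssetManager assets, float[] pixels, float[] scores) {
    // Before: error codes and int[] dims.
    //   TensorFlowInferenceInterface tf = new TensorFlowInferenceInterface();
    //   if (tf.initializeTensorFlow(assets, "file:///android_asset/model.pb") != 0) { /* bail */ }
    //   tf.fillNodeFloat("input", new int[] {1, 224, 224, 3}, pixels);
    //   if (tf.runInference(new String[] {"output"}) != 0) { /* bail */ }
    //   tf.readNodeFloat("output", scores);

    // After: exceptions and trailing long... dims.
    TensorFlowInferenceInterface tf =
        new TensorFlowInferenceInterface(assets, "file:///android_asset/model.pb");
    tf.feed("input", pixels, 1, 224, 224, 3);
    tf.run(new String[] {"output"}); // throws RuntimeException on failure
    tf.fetch("output", scores);
  }
}
```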
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 713f63b806..1f180429b2 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
+import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
@@ -45,7 +46,13 @@ public class TensorFlowInferenceInterface {
private static final String TAG = "TensorFlowInferenceInterface";
private static final String ASSET_FILE_PREFIX = "file:///android_asset/";
- public TensorFlowInferenceInterface() {
+ /**
+ * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
+ *
+ * @param assetManager The AssetManager to use to load the model file.
+ * @param model The filepath to the GraphDef proto representing the model.
+ */
+ public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
try {
// Hack to see if the native libraries have been loaded.
@@ -63,16 +70,12 @@ public class TensorFlowInferenceInterface {
+ " libraries are present in the APK.");
}
}
- }
- /**
- * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file.
- *
- * @param assetManager The AssetManager to use to load the model file.
- * @param model The filepath to the GraphDef proto representing the model.
- * @return 0 on success.
- */
- public int initializeTensorFlow(AssetManager assetManager, String model) {
+ this.modelName = model;
+ this.g = new Graph();
+ this.sess = new Session(g);
+ this.runner = sess.runner();
+
final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX);
InputStream is = null;
try {
@@ -80,37 +83,42 @@ public class TensorFlowInferenceInterface {
is = assetManager.open(aname);
} catch (IOException e) {
if (hasAssetPrefix) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
// Perhaps the model file is not an asset but is on disk.
try {
is = new FileInputStream(model);
} catch (IOException e2) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e2.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e2);
}
}
try {
- load(is);
+ loadGraph(is, g);
is.close();
Log.i(TAG, "Successfully loaded model from '" + model + "'");
- return 0;
} catch (IOException e) {
- Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString());
- return 1;
+ throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
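Callers that previously branched on initializeTensorFlow()'s return value now need to handle the RuntimeException thrown above. A sketch, with a placeholder log tag and model path:

```java
import android.content.res.AssetManager;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

class LoadSketch {
  TensorFlowInferenceInterface load(AssetManager assets) {
    try {
      // The asset prefix forces asset loading; a bare path falls back
      // to FileInputStream, as the constructor above shows.
      return new TensorFlowInferenceInterface(assets, "file:///android_asset/model.pb");
    } catch (RuntimeException e) {
      Log.e("LoadSketch", "Failed to load model", e);
      return null;
    }
  }
}
```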
/**
- * Runs inference between the previously registered input nodes (via fillNode*) and the requested
- * output nodes. Output nodes can then be queried with the readNode* methods.
+ * Runs inference between the previously registered input nodes (via feed*) and the requested
+ * output nodes. Output nodes can then be queried with the fetch* methods.
*
* @param outputNames A list of output nodes which should be filled by the inference pass.
- * @return 0 on success.
*/
- public int runInference(String[] outputNames) {
- // Release any Tensors from the previous runInference calls.
+ public void run(String[] outputNames) {
+ run(outputNames, false);
+ }
+
+ /**
+ * Runs inference between the previously registered input nodes (via feed*) and the requested
+ * output nodes. Output nodes can then be queried with the fetch* methods.
+ *
+ * @param outputNames A list of output nodes which should be filled by the inference pass.
+ * @param enableStats If true, collect timing statistics for this run; adds overhead.
+ */
+ public void run(String[] outputNames, boolean enableStats) {
+ // Release any Tensors from the previous run calls.
closeFetches();
// Add fetches.
@@ -125,6 +133,10 @@ public class TensorFlowInferenceInterface {
if (enableStats) {
Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
fetchTensors = r.outputs;
+
+ if (runStats == null) {
+ runStats = new RunStats();
+ }
runStats.add(r.metadata);
} else {
fetchTensors = runner.run();
@@ -139,16 +151,13 @@ public class TensorFlowInferenceInterface {
+ "], outputs:["
+ TextUtils.join(", ", fetchNames)
+ "]");
- Log.e(TAG, "Inference exception: " + e.toString());
- return -1;
+ throw e;
} finally {
- // Always release the feeds (to save resources) and reset the runner, this runInference is
+ // Always release the feeds (to save resources) and reset the runner; this run is
// over.
closeFeeds();
runner = sess.runner();
}
-
- return 0;
}
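A sketch of the opt-in stats path. The accessor name getStatString() is assumed from the Javadoc further down ("Returns the last stat summary string"); the diff does not show its signature, and node names are placeholders.

```java
// Assumes tf was constructed as above; "input"/"output" are placeholders.
void profiledRun(TensorFlowInferenceInterface tf, float[] pixels, float[] scores) {
  tf.feed("input", pixels, 1, 224, 224, 3);
  tf.run(new String[] {"output"}, true);                // enableStats: adds overhead
  android.util.Log.v("ProfiledRun", tf.getStatString()); // assumed accessor name
  tf.fetch("output", scores);
}
```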
/** Returns a reference to the Graph describing the computation run during inference. */
@@ -156,15 +165,13 @@ public class TensorFlowInferenceInterface {
return g;
}
- /**
- * Whether to collect stats during inference. This should only be enabled when needed, as it will
- * add overhead.
- */
- public void enableStatLogging(boolean enabled) {
- enableStats = enabled;
- if (enableStats && runStats == null) {
- runStats = new RunStats();
+ public Operation graphOperation(String operationName) {
+ final Operation operation = g.operation(operationName);
+ if (operation == null) {
+ throw new RuntimeException(
+ "Node '" + operationName + "' does not exist in model '" + modelName + "'");
}
+ return operation;
}
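Because graphOperation() now throws for unknown names, it can double as a cheap sanity check before feeding. A sketch with a placeholder node name:

```java
// Fail fast if the model lacks the expected input node.
void checkGraph(TensorFlowInferenceInterface tf) {
  try {
    org.tensorflow.Operation op = tf.graphOperation("input"); // placeholder name
    android.util.Log.i("CheckGraph", "Found op: " + op.name());
  } catch (RuntimeException e) {
    android.util.Log.e("CheckGraph", "Expected input node is missing", e);
  }
}
```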
/** Returns the last stat summary string if logging is enabled. */
@@ -185,7 +192,15 @@ public class TensorFlowInferenceInterface {
runStats.close();
}
runStats = null;
- enableStats = false;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
}
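finalize() is only a safety net; the native Graph and Session resources are better released deterministically via the close() method it delegates to. A sketch with placeholder names and shape:

```java
// Deterministic cleanup; rely on finalize() only as a last resort.
void runOnce(TensorFlowInferenceInterface tf, float[] pixels, float[] scores) {
  try {
    tf.feed("input", pixels, 1, 224, 224, 3); // placeholder names/shape
    tf.run(new String[] {"output"});
    tf.fetch("output", scores);
  } finally {
    tf.close();
  }
}
```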
// Methods for taking a native Tensor and filling it with values from Java arrays.
@@ -196,8 +211,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFloat(String inputName, int[] dims, float[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), FloatBuffer.wrap(src)));
+ public void feed(String inputName, float[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src)));
}
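The dims now arrive as trailing varargs longs, so call sites lose the int[] allocation. A sketch with a placeholder shape:

```java
void feedExample(TensorFlowInferenceInterface tf) {
  // Placeholder shape; the array length matches the dims product.
  float[] pixels = new float[1 * 224 * 224 * 3];
  tf.feed("input", pixels, 1, 224, 224, 3);
}
```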
/**
@@ -206,8 +221,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeInt(String inputName, int[] dims, int[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), IntBuffer.wrap(src)));
+ public void feed(String inputName, int[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src)));
}
/**
@@ -216,8 +231,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeDouble(String inputName, int[] dims, double[] src) {
- addFeed(inputName, Tensor.create(mkDims(dims), DoubleBuffer.wrap(src)));
+ public void feed(String inputName, double[] src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src)));
}
/**
@@ -226,8 +241,8 @@ public class TensorFlowInferenceInterface {
* as many elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeByte(String inputName, int[] dims, byte[] src) {
- addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), ByteBuffer.wrap(src)));
+ public void feed(String inputName, byte[] src, long... dims) {
+ addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src)));
}
// Methods for taking a native Tensor and filling it with src from Java native IO buffers.
@@ -239,8 +254,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromFloatBuffer(String inputName, IntBuffer dims, FloatBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, FloatBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -250,8 +265,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromIntBuffer(String inputName, IntBuffer dims, IntBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, IntBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -261,8 +276,8 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromDoubleBuffer(String inputName, IntBuffer dims, DoubleBuffer src) {
- addFeed(inputName, Tensor.create(mkDims(dims), src));
+ public void feed(String inputName, DoubleBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(dims, src));
}
/**
@@ -272,52 +287,44 @@ public class TensorFlowInferenceInterface {
* elements as that of the destination Tensor. If {@link src} has more elements than the
* destination has capacity, the copy is truncated.
*/
- public void fillNodeFromByteBuffer(String inputName, IntBuffer dims, ByteBuffer src) {
- addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), src));
+ public void feed(String inputName, ByteBuffer src, long... dims) {
+ addFeed(inputName, Tensor.create(DataType.UINT8, dims, src));
}
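A sketch of feeding UINT8 data from a direct, native-ordered buffer; the node name and image shape are placeholders:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

void feedImageBytes(TensorFlowInferenceInterface tf) {
  // Placeholder shape; buffer capacity matches the dims product.
  ByteBuffer image = ByteBuffer.allocateDirect(224 * 224 * 3).order(ByteOrder.nativeOrder());
  // ... write pixel bytes into image ...
  image.rewind(); // the copy starts from the buffer's current position
  tf.feed("input_bytes", image, 1, 224, 224, 3);
}
```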
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeFloat(String outputName, float[] dst) {
- return readNodeIntoFloatBuffer(outputName, FloatBuffer.wrap(dst));
+ public void fetch(String outputName, float[] dst) {
+ fetch(outputName, FloatBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeInt(String outputName, int[] dst) {
- return readNodeIntoIntBuffer(outputName, IntBuffer.wrap(dst));
+ public void fetch(String outputName, int[] dst) {
+ fetch(outputName, IntBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeDouble(String outputName, double[] dst) {
- return readNodeIntoDoubleBuffer(outputName, DoubleBuffer.wrap(dst));
+ public void fetch(String outputName, double[] dst) {
+ fetch(outputName, DoubleBuffer.wrap(dst));
}
/**
* Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link
* dst} must have length greater than or equal to that of the source Tensor. This operation will
* not affect dst's content past the source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeByte(String outputName, byte[] dst) {
- return readNodeIntoByteBuffer(outputName, ByteBuffer.wrap(dst));
+ public void fetch(String outputName, byte[] dst) {
+ fetch(outputName, ByteBuffer.wrap(dst));
}
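A fetch sketch: the destination array must be at least as large as the source Tensor (here a placeholder 1000-way classifier output).

```java
void fetchScores(TensorFlowInferenceInterface tf) {
  float[] scores = new float[1000]; // >= number of output elements
  tf.fetch("output", scores);       // placeholder node name
}
```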
/**
@@ -325,16 +332,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoFloatBuffer(String outputName, FloatBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, FloatBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -342,16 +342,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoIntBuffer(String outputName, IntBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, IntBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -359,16 +352,9 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoDoubleBuffer(String outputName, DoubleBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, DoubleBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
/**
@@ -376,22 +362,12 @@ public class TensorFlowInferenceInterface {
* <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than
* or equal to that of the source Tensor. This operation will not affect dst's content past the
* source Tensor's size.
- *
- * @return 0 on success, -1 on failure.
*/
- public int readNodeIntoByteBuffer(String outputName, ByteBuffer dst) {
- Tensor t = getTensor(outputName);
- if (t == null) {
- return -1;
- }
- t.writeTo(dst);
- return 0;
+ public void fetch(String outputName, ByteBuffer dst) {
+ getTensor(outputName).writeTo(dst);
}
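For the buffer overloads, a sketch of reading into a direct, native-ordered FloatBuffer; the size and node name are placeholders:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

void fetchIntoBuffer(TensorFlowInferenceInterface tf) {
  // Capacity must be >= the output size (placeholder: 1000 floats, 4 bytes each).
  FloatBuffer out =
      ByteBuffer.allocateDirect(1000 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
  tf.fetch("output", out);
}
```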
- private void load(InputStream is) throws IOException {
- this.g = new Graph();
- this.sess = new Session(g);
- this.runner = sess.runner();
+ private void loadGraph(InputStream is, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
Trace.beginSection("initializeTensorFlow");
@@ -425,26 +401,6 @@ public class TensorFlowInferenceInterface {
"Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version());
}
- // The TensorFlowInferenceInterface API used int[] for dims, but the underlying TensorFlow runtime
- // allows for 64-bit dimension sizes, so it needs to be converted to a long[]
- private long[] mkDims(int[] dims) {
- long[] ret = new long[dims.length];
- for (int i = 0; i < dims.length; ++i) {
- ret[i] = (long) dims[i];
- }
- return ret;
- }
-
- // Similar to mkDims(int[]), with the shape provided in an IntBuffer.
- private long[] mkDims(IntBuffer dims) {
- if (dims.hasArray()) {
- return mkDims(dims.array());
- }
- int[] copy = new int[dims.remaining()];
- dims.duplicate().get(copy);
- return mkDims(copy);
- }
-
private void addFeed(String inputName, Tensor t) {
// The string format accepted by TensorFlowInferenceInterface is node_name[:output_index].
TensorId tid = TensorId.parse(inputName);
@@ -485,9 +441,10 @@ public class TensorFlowInferenceInterface {
if (n.equals(outputName)) {
return fetchTensors.get(i);
}
- i++;
+ ++i;
}
- return null;
+ throw new RuntimeException(
+ "Node '" + outputName + "' was not provided to run(), so it cannot be read");
}
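One behavioral consequence worth noting: fetching a node that was not passed to run() used to return -1 and now surfaces as a RuntimeException. A sketch with placeholder names:

```java
void fetchMistake(TensorFlowInferenceInterface tf, float[] scores) {
  tf.run(new String[] {"output"});
  try {
    tf.fetch("other_output", scores); // never requested from run()
  } catch (RuntimeException e) {
    android.util.Log.w("FetchMistake", "Output was not provided to run()", e);
  }
}
```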
private void closeFeeds() {
@@ -507,10 +464,11 @@ public class TensorFlowInferenceInterface {
}
- // State immutable between initializeTensorFlow calls.
+ // Immutable state.
- private Graph g;
- private Session sess;
+ private final String modelName;
+ private final Graph g;
+ private final Session sess;
- // State reset on every call to runInference.
+ // State reset on every call to run.
private Session.Runner runner;
private List<String> feedNames = new ArrayList<String>();
private List<Tensor> feedTensors = new ArrayList<Tensor>();
@@ -518,6 +476,5 @@ public class TensorFlowInferenceInterface {
private List<Tensor> fetchTensors = new ArrayList<Tensor>();
// Mutable state.
- private boolean enableStats;
private RunStats runStats;
}