diff options
author | Andrew Harp <andrewharp@google.com> | 2017-03-23 14:20:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 15:39:26 -0700 |
commit | 2de2d52164e2b3d0ddb69b7c1037452419de25ce (patch) | |
tree | 01db062c17fe85bcecbd560faa12ec89b58c204e /tensorflow/contrib/android | |
parent | df70d61ea2a76fc226237ef663e94245fb523376 (diff) |
Android: update TensorFlowInferenceInterface: replace error code returns with exceptions, use longs for dimensions, and simplify method signatures where possible.
Note this is a breaking change and will cause the TF demo in tensorflow/examples/android to not work with older jars/.so binaries built from tensorflow/contrib/android
Change: 151061702
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r-- | tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 227 |
1 files changed, 92 insertions, 135 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 713f63b806..1f180429b2 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import org.tensorflow.DataType; import org.tensorflow.Graph; +import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; @@ -45,7 +46,13 @@ public class TensorFlowInferenceInterface { private static final String TAG = "TensorFlowInferenceInterface"; private static final String ASSET_FILE_PREFIX = "file:///android_asset/"; - public TensorFlowInferenceInterface() { + /* + * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. + * + * @param assetManager The AssetManager to use to load the model file. + * @param model The filepath to the GraphDef proto representing the model. + */ + public TensorFlowInferenceInterface(AssetManager assetManager, String model) { Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); try { // Hack to see if the native libraries have been loaded. @@ -63,16 +70,12 @@ public class TensorFlowInferenceInterface { + " libraries are present in the APK."); } } - } - /** - * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. - * - * @param assetManager The AssetManager to use to load the model file. - * @param model The filepath to the GraphDef proto representing the model. - * @return 0 on success. - */ - public int initializeTensorFlow(AssetManager assetManager, String model) { + this.modelName = model; + this.g = new Graph(); + this.sess = new Session(g); + this.runner = sess.runner(); + final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX); InputStream is = null; try { @@ -80,37 +83,42 @@ public class TensorFlowInferenceInterface { is = assetManager.open(aname); } catch (IOException e) { if (hasAssetPrefix) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } // Perhaps the model file is not an asset but is on disk. try { is = new FileInputStream(model); } catch (IOException e2) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e2.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } } try { - load(is); + loadGraph(is, g); is.close(); Log.i(TAG, "Successfully loaded model from '" + model + "'"); - return 0; } catch (IOException e) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } } /** - * Runs inference between the previously registered input nodes (via fillNode*) and the requested - * output nodes. Output nodes can then be queried with the readNode* methods. + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. * * @param outputNames A list of output nodes which should be filled by the inference pass. - * @return 0 on success. */ - public int runInference(String[] outputNames) { - // Release any Tensors from the previous runInference calls. + public void run(String[] outputNames) { + run(outputNames, false); + } + + /** + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. + * + * @param outputNames A list of output nodes which should be filled by the inference pass. + */ + public void run(String[] outputNames, boolean enableStats) { + // Release any Tensors from the previous run calls. closeFetches(); // Add fetches. @@ -125,6 +133,10 @@ public class TensorFlowInferenceInterface { if (enableStats) { Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata(); fetchTensors = r.outputs; + + if (runStats == null) { + runStats = new RunStats(); + } runStats.add(r.metadata); } else { fetchTensors = runner.run(); @@ -139,16 +151,13 @@ public class TensorFlowInferenceInterface { + "], outputs:[" + TextUtils.join(", ", fetchNames) + "]"); - Log.e(TAG, "Inference exception: " + e.toString()); - return -1; + throw e; } finally { - // Always release the feeds (to save resources) and reset the runner, this runInference is + // Always release the feeds (to save resources) and reset the runner, this run is // over. closeFeeds(); runner = sess.runner(); } - - return 0; } /** Returns a reference to the Graph describing the computation run during inference. */ @@ -156,15 +165,13 @@ public class TensorFlowInferenceInterface { return g; } - /** - * Whether to collect stats during inference. This should only be enabled when needed, as it will - * add overhead. - */ - public void enableStatLogging(boolean enabled) { - enableStats = enabled; - if (enableStats && runStats == null) { - runStats = new RunStats(); + public Operation graphOperation(String operationName) { + final Operation operation = g.operation(operationName); + if (operation == null) { + throw new RuntimeException( + "Node '" + operationName + "' does not exist in model '" + modelName + "'"); } + return operation; } /** Returns the last stat summary string if logging is enabled. */ @@ -185,7 +192,15 @@ public class TensorFlowInferenceInterface { runStats.close(); } runStats = null; - enableStats = false; + } + + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } } // Methods for taking a native Tensor and filling it with values from Java arrays. @@ -196,8 +211,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFloat(String inputName, int[] dims, float[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), FloatBuffer.wrap(src))); + public void feed(String inputName, float[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src))); } /** @@ -206,8 +221,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeInt(String inputName, int[] dims, int[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), IntBuffer.wrap(src))); + public void feed(String inputName, int[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src))); } /** @@ -216,8 +231,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeDouble(String inputName, int[] dims, double[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), DoubleBuffer.wrap(src))); + public void feed(String inputName, double[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); } /** @@ -226,8 +241,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeByte(String inputName, int[] dims, byte[] src) { - addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), ByteBuffer.wrap(src))); + public void feed(String inputName, byte[] src, long... dims) { + addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src))); } // Methods for taking a native Tensor and filling it with src from Java native IO buffers. @@ -239,8 +254,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromFloatBuffer(String inputName, IntBuffer dims, FloatBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, FloatBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -250,8 +265,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromIntBuffer(String inputName, IntBuffer dims, IntBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, IntBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -261,8 +276,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromDoubleBuffer(String inputName, IntBuffer dims, DoubleBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, DoubleBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -272,52 +287,44 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromByteBuffer(String inputName, IntBuffer dims, ByteBuffer src) { - addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), src)); + public void feed(String inputName, ByteBuffer src, long... dims) { + addFeed(inputName, Tensor.create(DataType.UINT8, dims, src)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeFloat(String outputName, float[] dst) { - return readNodeIntoFloatBuffer(outputName, FloatBuffer.wrap(dst)); + public void fetch(String outputName, float[] dst) { + fetch(outputName, FloatBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeInt(String outputName, int[] dst) { - return readNodeIntoIntBuffer(outputName, IntBuffer.wrap(dst)); + public void fetch(String outputName, int[] dst) { + fetch(outputName, IntBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeDouble(String outputName, double[] dst) { - return readNodeIntoDoubleBuffer(outputName, DoubleBuffer.wrap(dst)); + public void fetch(String outputName, double[] dst) { + fetch(outputName, DoubleBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeByte(String outputName, byte[] dst) { - return readNodeIntoByteBuffer(outputName, ByteBuffer.wrap(dst)); + public void fetch(String outputName, byte[] dst) { + fetch(outputName, ByteBuffer.wrap(dst)); } /** @@ -325,16 +332,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoFloatBuffer(String outputName, FloatBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, FloatBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -342,16 +342,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoIntBuffer(String outputName, IntBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, IntBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -359,16 +352,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoDoubleBuffer(String outputName, DoubleBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, DoubleBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -376,22 +362,12 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoByteBuffer(String outputName, ByteBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, ByteBuffer dst) { + getTensor(outputName).writeTo(dst); } - private void load(InputStream is) throws IOException { - this.g = new Graph(); - this.sess = new Session(g); - this.runner = sess.runner(); + private void loadGraph(InputStream is, Graph g) throws IOException { final long startMs = System.currentTimeMillis(); Trace.beginSection("initializeTensorFlow"); @@ -425,26 +401,6 @@ public class TensorFlowInferenceInterface { "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); } - // The TensorFlowInferenceInterface API used int[] for dims, but the underlying TensorFlow runtime - // allows for 64-bit dimension sizes, so it needs to be converted to a long[] - private long[] mkDims(int[] dims) { - long[] ret = new long[dims.length]; - for (int i = 0; i < dims.length; ++i) { - ret[i] = (long) dims[i]; - } - return ret; - } - - // Similar to mkDims(int[]), with the shape provided in an IntBuffer. - private long[] mkDims(IntBuffer dims) { - if (dims.hasArray()) { - return mkDims(dims.array()); - } - int[] copy = new int[dims.remaining()]; - dims.duplicate().get(copy); - return mkDims(copy); - } - private void addFeed(String inputName, Tensor t) { // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. TensorId tid = TensorId.parse(inputName); @@ -485,9 +441,10 @@ public class TensorFlowInferenceInterface { if (n.equals(outputName)) { return fetchTensors.get(i); } - i++; + ++i; } - return null; + throw new RuntimeException( + "Node '" + outputName + "' was not provided to run(), so it cannot be read"); } private void closeFeeds() { @@ -507,10 +464,11 @@ public class TensorFlowInferenceInterface { } // State immutable between initializeTensorFlow calls. - private Graph g; - private Session sess; + private final String modelName; + private final Graph g; + private final Session sess; - // State reset on every call to runInference. + // State reset on every call to run. private Session.Runner runner; private List<String> feedNames = new ArrayList<String>(); private List<Tensor> feedTensors = new ArrayList<Tensor>(); @@ -518,6 +476,5 @@ public class TensorFlowInferenceInterface { private List<Tensor> fetchTensors = new ArrayList<Tensor>(); // Mutable state. - private boolean enableStats; private RunStats runStats; } |