aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/android
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-08-25 14:01:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-25 14:04:48 -0700
commit008910f1122d115a6d7430bfcc63cf4296c7467d (patch)
treee50199dcceed004cecc8510f9251f5e04734800f /tensorflow/contrib/android
parent005a88f6cc6e4e8c94a4f2d1980737855c4592f4 (diff)
Merge changes from github.
END_PUBLIC --- Commit b30ce4714 authored by James Qin<jamesqin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Revamp CudnnRNN Saveables 1. Use a lossy way to save/restore cudnn biases during checkpointing. Cudnn uses 2 biases each gate for all RNNs while tf uses one. To allow cudnn checkpoints to be compatible with both Cudnn and platform-independent impls, previously both individual bias and summed biases each gate were stored. The new way only stores the bias sum for each gate, and split it half-half when restoring from a cudnn graph. Doing this does not cause problems since RNNs do not use weight-decay to regularize. 2. Use inheritance instead of branching * Split RNNParamsSaveable to 1 base class and 4 subclasses. * Extract common routines and only overwrite rnn-type-specific pieces in subclasses. PiperOrigin-RevId: 166413989 --- Commit ebc421daf authored by Alan Yee<alyee@ucsd.edu> Committed by Jonathan Hseu<vomjom@vomjom.net>: Update documentation for contrib (#12424) * Update __init__.py Remove ## for standardization of api docs * Create README.md Add README to define this directory's purpose * Update __init.py Markdown styling does not show up well in api docs * Update README.md Add short mention of describing what to deprecate * Update README.md Capitalize title * Update README.md Revert README change * Delete README.md --- Commit fd295394d authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use latest version of nsync library, which now allows use of cmake on MacOS. PiperOrigin-RevId: 166411437 --- Commit 587d728e0 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Refactor reduce-precision-insertion filters, add several more options. In particular, this adds the ability to add reduce-precision operations after fusion nodes based on the contents of those fusion nodes, and the ability to filter operations based on the "op_name" metadata. PiperOrigin-RevId: 166408392 --- Commit 3142f8ef5 authored by Ali Yahya<alive@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Steps toward making ResourceVariables compatible with Eager. This change forces the value of the reuse flag in variable scopes to be tf.AUTO_REUSE when in Eager mode. This change also adds comprehensive Eager tests for ResourceVariable. PiperOrigin-RevId: 166408161 --- Commit b2ce45150 authored by Igor Ganichev<iga@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make Graph::IsValidNode public It can be reimplemented with existing public APIs, but instead of doing so, making this one public seems better. PiperOrigin-RevId: 166407897 --- Commit 0a2f40e92 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA::CPU] Fix HLO profiling in parallel CPU backend. PiperOrigin-RevId: 166400211 --- Commit c4a58e3fd authored by Yao Zhang<yaozhang@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Identify frame ids for all nodes in a graph. PiperOrigin-RevId: 166397615 --- Commit 989713f26 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 166294015 PiperOrigin-RevId: 166521502
Diffstat (limited to 'tensorflow/contrib/android')
-rw-r--r--tensorflow/contrib/android/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java63
2 files changed, 47 insertions, 18 deletions
diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt
index 1f86288cf9..f61e9560ef 100644
--- a/tensorflow/contrib/android/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/android/cmake/CMakeLists.txt
@@ -28,7 +28,7 @@ set_target_properties(lib_proto PROPERTIES IMPORTED_LOCATION
add_library(lib_nsync STATIC IMPORTED )
set_target_properties(lib_nsync PROPERTIES IMPORTED_LOCATION
- ${TARGET_NSYNC_LIB})
+ ${TARGET_NSYNC_LIB}/lib/libnsync.a)
add_library(lib_tf STATIC IMPORTED )
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 587f2941e5..9b7f394258 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -55,23 +55,7 @@ public class TensorFlowInferenceInterface {
* @param model The filepath to the GraphDef proto representing the model.
*/
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
- Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
- try {
- // Hack to see if the native libraries have been loaded.
- new RunStats();
- Log.i(TAG, "TensorFlow native methods already loaded");
- } catch (UnsatisfiedLinkError e1) {
- Log.i(
- TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
- try {
- System.loadLibrary("tensorflow_inference");
- Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
- } catch (UnsatisfiedLinkError e2) {
- throw new RuntimeException(
- "Native TF methods not found; check that the correct native"
- + " libraries are present in the APK.");
- }
- }
+ prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
@@ -102,6 +86,31 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
+
+ /*
+ * Load a TensorFlow model from provided InputStream.
+ * Note: The InputStream will not be closed after loading model, users need to
+ * close it themselves.
+ *
+ * @param is The InputStream to use to load the model.
+ */
+ public TensorFlowInferenceInterface(InputStream is) {
+ prepareNativeRuntime();
+
+ // modelName is redundant for model loading from input stream, here is for
+ // avoiding error in initialization as modelName is marked final.
+ this.modelName = "";
+ this.g = new Graph();
+ this.sess = new Session(g);
+ this.runner = sess.runner();
+
+ try {
+ loadGraph(is, g);
+ Log.i(TAG, "Successfully loaded model from the input stream");
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to load model from the input stream", e);
+ }
+ }
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
@@ -408,6 +417,26 @@ public class TensorFlowInferenceInterface {
public void fetch(String outputName, ByteBuffer dst) {
getTensor(outputName).writeTo(dst);
}
+
+ private void prepareNativeRuntime() {
+ Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
+ try {
+ // Hack to see if the native libraries have been loaded.
+ new RunStats();
+ Log.i(TAG, "TensorFlow native methods already loaded");
+ } catch (UnsatisfiedLinkError e1) {
+ Log.i(
+ TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
+ try {
+ System.loadLibrary("tensorflow_inference");
+ Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
+ } catch (UnsatisfiedLinkError e2) {
+ throw new RuntimeException(
+ "Native TF methods not found; check that the correct native"
+ + " libraries are present in the APK.");
+ }
+ }
+ }
private void loadGraph(InputStream is, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();