aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/android/jni
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2016-09-20 12:42:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-20 13:48:15 -0700
commit54bd703932480790643c641796386c225e4e1cf3 (patch)
treec2b80560ee090aec553ab5def5ab6694eb60f444 /tensorflow/contrib/android/jni
parent7d4bee12533e027ae04be400712044da326563e8 (diff)
Add generic Java interface for TF inference.
Change: 133748240
Diffstat (limited to 'tensorflow/contrib/android/jni')
-rw-r--r--tensorflow/contrib/android/jni/jni_utils.cc139
-rw-r--r--tensorflow/contrib/android/jni/jni_utils.h41
-rw-r--r--tensorflow/contrib/android/jni/limiting_file_input_stream.h67
-rw-r--r--tensorflow/contrib/android/jni/tensorflow_inference_jni.cc270
-rw-r--r--tensorflow/contrib/android/jni/tensorflow_inference_jni.h63
-rw-r--r--tensorflow/contrib/android/jni/version_script.lds11
6 files changed, 591 insertions, 0 deletions
diff --git a/tensorflow/contrib/android/jni/jni_utils.cc b/tensorflow/contrib/android/jni/jni_utils.cc
new file mode 100644
index 0000000000..cf3e7ab252
--- /dev/null
+++ b/tensorflow/contrib/android/jni/jni_utils.cc
@@ -0,0 +1,139 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/android/jni/jni_utils.h"
+
+#include <android/asset_manager.h>
+#include <android/asset_manager_jni.h>
+#include <jni.h>
+#include <stdlib.h>
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "google/protobuf/io/coded_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+#include "google/protobuf/message_lite.h"
+#include "tensorflow/contrib/android/jni/limiting_file_input_stream.h"
+#include "tensorflow/core/platform/logging.h"
+
+static const char* const ASSET_PREFIX = "file:///android_asset/";
+
+namespace {
+// Adapts a std::ifstream to protobuf's CopyingInputStream interface so that a
+// plain filesystem file can be parsed through a CopyingInputStreamAdaptor
+// (see PortableReadFileToProto below).
+class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
+ public:
+  explicit IfstreamInputStream(const std::string& file_name)
+      : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {}
+  ~IfstreamInputStream() { ifs_.close(); }
+
+  // Reads up to |size| bytes into |buffer|. Returns the number of bytes
+  // actually read, or -1 if the stream is in a failed state (the error
+  // convention expected by CopyingInputStream).
+  int Read(void* buffer, int size) {
+    if (!ifs_) {
+      return -1;
+    }
+    ifs_.read(static_cast<char*>(buffer), size);
+    return ifs_.gcount();
+  }
+
+ private:
+  std::ifstream ifs_;
+};
+
+}  // namespace
+
+// Parses a serialized MessageLite from a file on the regular filesystem.
+// Returns true on success, false on I/O or parse failure.
+bool PortableReadFileToProto(const std::string& file_name,
+                             ::google::protobuf::MessageLite* proto) {
+  ::google::protobuf::io::CopyingInputStreamAdaptor stream(
+      new IfstreamInputStream(file_name));
+  // The adaptor takes ownership of the IfstreamInputStream and deletes it.
+  stream.SetOwnsCopyingStream(true);
+  // TODO(jiayq): the following coded stream is for debugging purposes to allow
+  // one to parse arbitrarily large messages for MessageLite. One most likely
+  // doesn't want to put protobufs larger than 64MB on Android, so we should
+  // eventually remove this and fail loudly when a large protobuf is passed in.
+  ::google::protobuf::io::CodedInputStream coded_stream(&stream);
+  // Total bytes hard limit / warning limit are set to 1GB and 512MB
+  // respectively.
+  coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
+  return proto->ParseFromCodedStream(&coded_stream);
+}
+
+// Returns true iff |filename| begins with ASSET_PREFIX, i.e. it refers to a
+// file bundled in the APK's assets rather than on the regular filesystem.
+bool IsAsset(const char* const filename) {
+  // strstr() == filename means the first match occurs at position 0,
+  // i.e. ASSET_PREFIX is a prefix of filename.
+  return strstr(filename, ASSET_PREFIX) == filename;
+}
+
+// Parses a serialized MessageLite from either a plain filesystem path or a
+// "file:///android_asset/" URI. CHECK-fails (aborts the process) on any read
+// or parse error, hence "OrDie". |asset_manager| may only be null when
+// |filename| is a plain path.
+void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
+                          const char* const filename,
+                          google::protobuf::MessageLite* message) {
+  // Plain filesystem path: delegate to the ifstream-based reader.
+  if (!IsAsset(filename)) {
+    VLOG(0) << "Opening file: " << filename;
+    CHECK(PortableReadFileToProto(filename, message));
+    return;
+  }
+
+  CHECK_NOTNULL(asset_manager);
+
+  // Strip the "file:///android_asset/" prefix to get the asset-relative path.
+  const char* const asset_filename = filename + strlen(ASSET_PREFIX);
+  AAsset* asset =
+      AAssetManager_open(asset_manager, asset_filename, AASSET_MODE_STREAMING);
+  CHECK_NOTNULL(asset);
+
+  // openFileDescriptor succeeds only for uncompressed assets; it yields a
+  // descriptor for the enclosing APK plus the asset's offset and length.
+  off_t start;
+  off_t length;
+  const int fd = AAsset_openFileDescriptor(asset, &start, &length);
+
+  if (fd >= 0) {
+    // Limit the stream to the end of this asset within the APK.
+    ::tensorflow::android::LimitingFileInputStream is(fd, start + length);
+    google::protobuf::io::CopyingInputStreamAdaptor adaptor(&is);
+    // If the file is smaller than protobuf's default limit, avoid copies.
+    // (64MB is CodedInputStream's default total-bytes limit.)
+    if (length < (64 * 1024 * 1024)) {
+      // If it has a file descriptor that means it can be memmapped directly
+      // from the APK.
+      VLOG(0) << "Opening asset " << asset_filename
+              << " from disk with zero-copy.";
+      // Skip to the asset's start offset within the APK before parsing.
+      adaptor.Skip(start);
+      CHECK(message->ParseFromZeroCopyStream(&adaptor));
+    } else {
+      ::google::protobuf::io::CodedInputStream coded_stream(&adaptor);
+      // Total bytes hard limit / warning limit are set to 1GB and 512MB
+      // respectively.
+      coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
+      coded_stream.Skip(start);
+      CHECK(message->ParseFromCodedStream(&coded_stream));
+    }
+  } else {
+    // It may be compressed, in which case we have to uncompress
+    // it to memory first.
+    VLOG(0) << "Opening asset " << asset_filename << " from disk with copy.";
+    const off_t data_size = AAsset_getLength(asset);
+    const void* const memory = AAsset_getBuffer(asset);
+    CHECK(message->ParseFromArray(memory, data_size));
+  }
+  AAsset_close(asset);
+}
+
+// Copies the contents of a Java string into a std::string (JNI modified
+// UTF-8 encoding) and releases the JNI buffer before returning.
+std::string GetString(JNIEnv* env, jstring java_string) {
+  // NOTE(review): GetStringUTFChars may return nullptr on allocation failure;
+  // the std::string construction below would then crash -- TODO confirm
+  // whether a null check is warranted here.
+  const char* raw_string = env->GetStringUTFChars(java_string, 0);
+  std::string return_str(raw_string);
+  env->ReleaseStringUTFChars(java_string, raw_string);
+  return return_str;
+}
+
+// Returns the current wall-clock time in microseconds since the epoch.
+tensorflow::int64 CurrentWallTimeUs() {
+  struct timeval tv;
+  gettimeofday(&tv, NULL);
+  // NOTE(review): tv_sec * 1000000 is evaluated in time_t/long arithmetic,
+  // which overflows on 32-bit targets; casting tv.tv_sec to int64 before the
+  // multiply would avoid that -- verify target ABIs.
+  return tv.tv_sec * 1000000 + tv.tv_usec;
+}
diff --git a/tensorflow/contrib/android/jni/jni_utils.h b/tensorflow/contrib/android/jni/jni_utils.h
new file mode 100644
index 0000000000..c1d9060951
--- /dev/null
+++ b/tensorflow/contrib/android/jni/jni_utils.h
@@ -0,0 +1,41 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef ORG_TENSORFLOW_JNI_JNI_UTILS_H_ // NOLINT
+#define ORG_TENSORFLOW_JNI_JNI_UTILS_H_ // NOLINT
+
+#include <jni.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+// Forward declaration so this header does not pull in the NDK asset-manager
+// headers. NOTE(review): the NDK declares AAssetManager via
+// `typedef struct AAssetManager ...`; some compilers warn on a mismatched
+// class/struct forward declaration -- verify.
+class AAssetManager;
+
+// Parses a serialized MessageLite from a file on the regular filesystem.
+// Returns true on success.
+bool PortableReadFileToProto(const std::string& file_name,
+                             ::google::protobuf::MessageLite* proto)
+    TF_MUST_USE_RESULT;
+
+// Deserializes the contents of a file into memory.
+// Accepts a plain filesystem path or a "file:///android_asset/" URI and
+// CHECK-fails (aborts) on any read or parse error.
+void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
+                          const char* const filename,
+                          google::protobuf::MessageLite* message);
+
+// Copies a Java string into a std::string.
+std::string GetString(JNIEnv* env, jstring java_string);
+
+// Current wall-clock time in microseconds since the epoch.
+tensorflow::int64 CurrentWallTimeUs();
+
+#endif // ORG_TENSORFLOW_JNI_JNI_UTILS_H_
diff --git a/tensorflow/contrib/android/jni/limiting_file_input_stream.h b/tensorflow/contrib/android/jni/limiting_file_input_stream.h
new file mode 100644
index 0000000000..b79676a2d3
--- /dev/null
+++ b/tensorflow/contrib/android/jni/limiting_file_input_stream.h
@@ -0,0 +1,67 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
+#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
+
+#include <unistd.h>
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
+
+namespace tensorflow {
+namespace android {
+
+// Input stream that reads a limited amount of data from an input file
+// descriptor.
+class LimitingFileInputStream
+    : public ::google::protobuf::io::CopyingInputStream {
+ public:
+  // Construct a stream to read from file <fd>, returning on the first <limit>
+  // bytes. If <fd> has fewer than <limit> bytes, then limit has no effect.
+  LimitingFileInputStream(int fd, int limit) : fd_(fd), bytes_left_(limit) {}
+  ~LimitingFileInputStream() {}
+
+  // Reads up to min(size, bytes remaining under the limit), retrying read()
+  // on EINTR. Returns the byte count (0 once the limit is exhausted, which
+  // protobuf treats as EOF) or a negative value on read error.
+  int Read(void* buffer, int size) {
+    int result;
+    do {
+      result = read(fd_, buffer, std::min(bytes_left_, size));
+    } while (result < 0 && errno == EINTR);
+
+    if (result < 0) {
+      // Genuine read error (not EOF); remember the cause.
+      errno_ = errno;
+    } else {
+      bytes_left_ -= result;
+    }
+    return result;
+  }
+
+  // Advances the file position by <count> bytes. Returns <count> on success,
+  // -1 on seek failure.
+  // NOTE(review): count is not clamped to bytes_left_, so a large skip can
+  // drive bytes_left_ negative and move past the intended limit -- confirm
+  // callers never skip beyond it.
+  int Skip(int count) {
+    if (lseek(fd_, count, SEEK_CUR) == (off_t)-1) {
+      return -1;
+    }
+    // Seek succeeded.
+    bytes_left_ -= count;
+    return count;
+  }
+
+ private:
+  const int fd_;    // Not owned; the caller is responsible for closing it.
+  int bytes_left_;  // Bytes still permitted to be read under the limit.
+  int errno_ = 0;   // Last read() errno. NOTE(review): recorded but never read.
+};
+
+} // namespace android
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
new file mode 100644
index 0000000000..593be3fa84
--- /dev/null
+++ b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
@@ -0,0 +1,270 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/android/jni/tensorflow_inference_jni.h"
+
+#include <android/asset_manager.h>
+#include <android/asset_manager_jni.h>
+#include <android/bitmap.h>
+
+#include <jni.h>
+#include <pthread.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#include <map>
+#include <queue>
+#include <sstream>
+#include <string>
+
+#include "tensorflow/contrib/android/jni/jni_utils.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/stat_summarizer.h"
+
+using namespace tensorflow;
+
+// Maps input node name -> (node name, tensor to feed for that node).
+typedef std::map<std::string, std::pair<std::string, tensorflow::Tensor> >
+    InputMap;
+
+// Variables associated with a single TF session.
+struct SessionVariables {
+  std::unique_ptr<tensorflow::Session> session;
+
+  long id = -1;  // Copied from Java field for convenience.
+  int num_runs = 0;           // Number of runInference() calls so far.
+  int64 timing_total_us = 0;  // Cumulative inference wall time (microseconds).
+
+  // Inputs registered via the fillNode* JNI methods, fed on the next run.
+  InputMap input_tensors;
+  // Output names requested by the last runInference() call; parallel to
+  // output_tensors (same length/order after a successful Run()).
+  std::vector<std::string> output_tensor_names;
+  std::vector<tensorflow::Tensor> output_tensors;
+};
+
+// Guards the global session map below. Linker-initialized so it is usable
+// before dynamic initializers have run.
+static tensorflow::mutex mutex_(tensorflow::LINKER_INITIALIZED);
+
+// Returns the lazily-created, never-destroyed global map from Java-side
+// session id to its native SessionVariables. Callers must hold mutex_
+// while touching the returned map (see PT_GUARDED_BY).
+std::map<int64, SessionVariables*>* GetSessionsSingleton() {
+  static std::map<int64, SessionVariables*>* sessions PT_GUARDED_BY(mutex_) =
+      new std::map<int64, SessionVariables*>();
+  return sessions;
+}
+
+// Looks up -- creating on first use -- the SessionVariables keyed by the
+// Java object's long "id" field. Acquires mutex_ internally, so the caller
+// must NOT already hold it (tensorflow::mutex is not recursive).
+inline static SessionVariables* GetSessionVars(JNIEnv* env, jobject thiz) {
+  // Read the "id" field (type long, JNI signature "J") from the Java object.
+  jclass clazz = env->GetObjectClass(thiz);
+  assert(clazz != nullptr);
+  jfieldID fid = env->GetFieldID(clazz, "id", "J");
+  assert(fid != nullptr);
+  const int64 id = env->GetLongField(thiz, fid);
+
+  // This method is thread-safe as we support working with multiple
+  // sessions simultaneously. However care must be taken at the calling
+  // level on a per-session basis.
+  mutex_lock l(mutex_);
+  std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
+  if (sessions.find(id) == sessions.end()) {
+    LOG(INFO) << "Creating new session variables for " << std::hex << id;
+    SessionVariables* vars = new SessionVariables;
+    vars->id = id;
+    sessions[id] = vars;
+  } else {
+    VLOG(0) << "Found session variables for " << std::hex << id;
+  }
+  return sessions[id];
+}
+
+// JNI entry point: loads the GraphDef named by |model| (a filesystem path or
+// "file:///android_asset/" URI), creates a tensorflow::Session for it, and
+// stores the session in this object's SessionVariables. Returns the
+// tensorflow::error::Code as jint (0 == OK). Idempotent: if a session
+// already exists for this object the call is a no-op.
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
+    JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model) {
+  SessionVariables* vars = GetSessionVars(env, thiz);
+
+  if (vars->session.get() != nullptr) {
+    LOG(INFO) << "Compute graph already loaded. skipping.";
+    return 0;
+  }
+
+  const int64 start_time = CurrentWallTimeUs();
+
+  const std::string model_str = GetString(env, model);
+
+  LOG(INFO) << "Loading Tensorflow.";
+
+  LOG(INFO) << "Making new SessionOptions.";
+  tensorflow::SessionOptions options;
+  tensorflow::ConfigProto& config = options.config;
+  LOG(INFO) << "Got config, " << config.device_count_size() << " devices";
+
+  // vars->session takes ownership of the raw Session pointer.
+  tensorflow::Session* session = tensorflow::NewSession(options);
+  vars->session.reset(session);
+  LOG(INFO) << "Session created.";
+
+  tensorflow::GraphDef tensorflow_graph;
+  LOG(INFO) << "Graph created.";
+
+  AAssetManager* const asset_manager =
+      AAssetManager_fromJava(env, java_asset_manager);
+  LOG(INFO) << "Acquired AssetManager.";
+
+  LOG(INFO) << "Reading file to proto: " << model_str;
+  // Aborts the process if the model cannot be read or parsed.
+  ReadFileToProtoOrDie(asset_manager, model_str.c_str(), &tensorflow_graph);
+
+  LOG(INFO) << "Creating session.";
+  tensorflow::Status s = session->Create(tensorflow_graph);
+
+  // Clear the proto to save memory space.
+  tensorflow_graph.Clear();
+
+  if (!s.ok()) {
+    LOG(ERROR) << "Could not create Tensorflow Graph: " << s;
+    return s.code();
+  }
+
+  LOG(INFO) << "Tensorflow graph loaded from: " << model_str;
+
+  const int64 end_time = CurrentWallTimeUs();
+  LOG(INFO) << "Initialization done in " << (end_time - start_time) / 1000.0
+            << "ms";
+
+  return s.code();
+}
+
+// Returns a pointer to the output tensor produced by the last Run() for the
+// given node name, or nullptr if that name was not among the requested
+// outputs. The pointer is owned by SessionVariables and is invalidated by
+// the next runInference() call.
+static tensorflow::Tensor* GetTensor(JNIEnv* env, jobject thiz,
+                                     jstring node_name_jstring) {
+  SessionVariables* vars = GetSessionVars(env, thiz);
+  std::string node_name = GetString(env, node_name_jstring);
+
+  int output_index = -1;
+  // NOTE(review): the loop bound uses output_tensors.size() while indexing
+  // output_tensor_names; this assumes the two vectors have equal length,
+  // which holds only after a successful Run() -- verify after a failed run.
+  for (int i = 0; i < vars->output_tensors.size(); ++i) {
+    if (vars->output_tensor_names[i] == node_name) {
+      output_index = i;
+      break;
+    }
+  }
+  if (output_index == -1) {
+    LOG(ERROR) << "Output [" << node_name << "] not found, aborting!";
+    return nullptr;
+  }
+
+  tensorflow::Tensor* output = &vars->output_tensors[output_index];
+  return output;
+}
+
+// JNI entry point: runs the session, feeding every tensor previously
+// registered via the fillNode* methods and fetching the outputs named in
+// |output_name_strings| (readable afterwards via the readNode* methods).
+// Returns the tensorflow::error::Code as jint (0 == OK).
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
+    JNIEnv* env, jobject thiz, jobjectArray output_name_strings) {
+  SessionVariables* vars = GetSessionVars(env, thiz);
+
+  // Add the requested outputs to the output list.
+  vars->output_tensor_names.clear();
+  for (int i = 0; i < env->GetArrayLength(output_name_strings); i++) {
+    jstring java_string =
+        (jstring)(env->GetObjectArrayElement(output_name_strings, i));
+    std::string output_name = GetString(env, java_string);
+    vars->output_tensor_names.push_back(output_name);
+  }
+
+  ++(vars->num_runs);
+  tensorflow::Status s;
+  int64 start_time, end_time;
+
+  start_time = CurrentWallTimeUs();
+
+  // Flatten the registered input map into the (name, tensor) vector form
+  // Session::Run expects.
+  std::vector<std::pair<std::string, tensorflow::Tensor> > input_tensors;
+  for (const auto& entry : vars->input_tensors) {
+    input_tensors.push_back(entry.second);
+  }
+
+  vars->output_tensors.clear();
+  // Empty third argument: no target nodes, only fetched outputs.
+  s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
+                         &(vars->output_tensors));
+  end_time = CurrentWallTimeUs();
+  const int64 elapsed_time_inf = end_time - start_time;
+  vars->timing_total_us += elapsed_time_inf;
+  VLOG(0) << "End computing. Ran in " << elapsed_time_inf / 1000 << "ms ("
+          << (vars->timing_total_us / vars->num_runs / 1000) << "ms avg over "
+          << vars->num_runs << " runs)";
+
+  if (!s.ok()) {
+    LOG(ERROR) << "Error during inference: " << s;
+  }
+  return s.code();
+}
+
+// JNI entry point: closes the session and frees this object's native
+// SessionVariables. Returns the Close() status code.
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
+  mutex_lock l(mutex_);
+  // NOTE(review): GetSessionVars() itself acquires mutex_, which is not
+  // recursive -- taking the lock here first looks like a self-deadlock;
+  // verify against tensorflow::mutex semantics.
+  SessionVariables* vars = GetSessionVars(env, thiz);
+
+  tensorflow::Status s = vars->session->Close();
+  if (!s.ok()) {
+    LOG(ERROR) << "Error closing session: " << s;
+  }
+
+  // Remove the entry from the global map before deleting it.
+  std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
+  sessions.erase(vars->id);
+  delete vars;
+
+  return s.code();
+}
+
+// TODO(andrewharp): Use memcpy to fill/read nodes.
+// Generates the JNI fillNode<DTYPE> implementations: copies a Java primitive
+// array into a freshly-created [x, y, z, d] tensor and registers it as a feed
+// under |node_name|, replacing any tensor previously registered for that
+// name. The JNI array copy is released with JNI_ABORT since the Java side is
+// never written back. Comments cannot go inside the macro body: a // comment
+// on a backslash-continued line would swallow the continuation.
+// NOTE(review): std::min(tensor_mapped.size(), array_size) mixes Eigen's
+// index type with int -- verify this deduces/compiles on all toolchains.
+#define FILL_NODE_METHOD(DTYPE, JAVA_DTYPE, TENSOR_DTYPE)                   \
+  FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) {                                  \
+    SessionVariables* vars = GetSessionVars(env, thiz);                     \
+    tensorflow::Tensor input_tensor(TENSOR_DTYPE,                           \
+                                    tensorflow::TensorShape({x, y, z, d})); \
+    auto tensor_mapped = input_tensor.flat<JAVA_DTYPE>();                   \
+    jboolean iCopied = JNI_FALSE;                                           \
+    j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied);  \
+    j##JAVA_DTYPE* value_ptr = values;                                      \
+    const int array_size = env->GetArrayLength(arr);                        \
+    for (int i = 0; i < std::min(tensor_mapped.size(), array_size); ++i) {  \
+      tensor_mapped(i) = *value_ptr++;                                      \
+    }                                                                       \
+    env->Release##DTYPE##ArrayElements(arr, values, JNI_ABORT);             \
+    std::string input_name = GetString(env, node_name);                     \
+    std::pair<std::string, tensorflow::Tensor> input_pair(input_name,       \
+                                                          input_tensor);    \
+    vars->input_tensors[input_name] = input_pair;                           \
+  }
+
+// Generates the JNI readNode<DTYPE> implementations: copies the output
+// tensor named by |node_name_jstring| (from the last runInference) into the
+// Java array |arr|, truncating to the smaller of the two sizes. Returns 0 on
+// success, -1 if the output was not found. Release mode 0 writes the values
+// back into the Java array. Comments cannot go inside the macro body: a //
+// comment on a backslash-continued line would swallow the continuation.
+#define READ_NODE_METHOD(DTYPE, JAVA_DTYPE)                                \
+  READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) {                                 \
+    SessionVariables* vars = GetSessionVars(env, thiz);                    \
+    Tensor* t = GetTensor(env, thiz, node_name_jstring);                   \
+    if (t == nullptr) {                                                    \
+      return -1;                                                           \
+    }                                                                      \
+    auto tensor_mapped = t->flat<JAVA_DTYPE>();                            \
+    jboolean iCopied = JNI_FALSE;                                          \
+    j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(arr, &iCopied); \
+    j##JAVA_DTYPE* value_ptr = values;                                     \
+    const int num_items = std::min(static_cast<int>(tensor_mapped.size()), \
+                                   env->GetArrayLength(arr));              \
+    for (int i = 0; i < num_items; ++i) {                                  \
+      *value_ptr++ = tensor_mapped(i);                                     \
+    }                                                                      \
+    env->Release##DTYPE##ArrayElements(arr, values, 0);                    \
+    return 0;                                                              \
+  }
+
+// Instantiate the fill/read implementations for each supported primitive
+// type (must match the FILL_NODE_SIGNATURE/READ_NODE_SIGNATURE declarations
+// in tensorflow_inference_jni.h).
+FILL_NODE_METHOD(Float, float, tensorflow::DT_FLOAT)
+FILL_NODE_METHOD(Int, int, tensorflow::DT_INT32)
+FILL_NODE_METHOD(Double, double, tensorflow::DT_DOUBLE)
+
+READ_NODE_METHOD(Float, float)
+READ_NODE_METHOD(Int, int)
+READ_NODE_METHOD(Double, double)
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.h b/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
new file mode 100644
index 0000000000..6c04553a94
--- /dev/null
+++ b/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
@@ -0,0 +1,63 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The methods are exposed to Java to allow for interaction with the native
+// TensorFlow code. See
+// tensorflow/examples/android/src/org/tensorflow/TensorFlowClassifier.java
+// for the Java counterparts.
+
+#ifndef ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT
+#define ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// Expands to the fully-qualified JNI symbol for a native method on the Java
+// class org.tensorflow.contrib.android.TensorFlowInferenceInterface.
+#define TENSORFLOW_METHOD(METHOD_NAME) \
+  Java_org_tensorflow_contrib_android_TensorFlowInferenceInterface_##METHOD_NAME // NOLINT
+
+// Declares fillNode<DTYPE>(node_name, x, y, z, d, arr): feeds an
+// [x, y, z, d] tensor populated from the Java array |arr|.
+#define FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE)                             \
+  JNIEXPORT void TENSORFLOW_METHOD(fillNode##DTYPE)(                       \
+      JNIEnv * env, jobject thiz, jstring node_name, jint x, jint y,  jint z, \
+      jint d, j##JAVA_DTYPE##Array arr)
+
+// Declares readNode<DTYPE>(node_name, arr): copies the named output tensor
+// into |arr|; returns 0 on success, -1 if the output is not found.
+#define READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE)            \
+  JNIEXPORT jint TENSORFLOW_METHOD(readNode##DTYPE)(      \
+      JNIEnv * env, jobject thiz, jstring node_name_jstring, \
+      j##JAVA_DTYPE##Array arr)
+
+// Loads the model and creates the TF session; returns error code (0 == OK).
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
+    JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model);
+
+// Runs the session with previously filled inputs, fetching the named
+// outputs; returns error code (0 == OK).
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
+    JNIEnv* env, jobject thiz, jobjectArray output_name_strings);
+
+// Closes the session and frees its native state; returns error code.
+JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz);
+
+FILL_NODE_SIGNATURE(Float, float);
+FILL_NODE_SIGNATURE(Int, int);
+FILL_NODE_SIGNATURE(Double, double);
+
+READ_NODE_SIGNATURE(Float, float);
+READ_NODE_SIGNATURE(Int, int);
+READ_NODE_SIGNATURE(Double, double);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT
diff --git a/tensorflow/contrib/android/jni/version_script.lds b/tensorflow/contrib/android/jni/version_script.lds
new file mode 100644
index 0000000000..38c93dda73
--- /dev/null
+++ b/tensorflow/contrib/android/jni/version_script.lds
@@ -0,0 +1,11 @@
+# Linker version script: restricts the shared library's dynamic symbol table
+# to the JNI entry points, keeping TensorFlow internals hidden from dlsym.
+VERS_1.0 {
+  # Export JNI symbols.
+  global:
+    Java_*;
+    JNI_OnLoad;
+    JNI_OnUnload;
+
+  # Hide everything else.
+  local:
+    *;
+};