aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/android/jni
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-13 15:05:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-13 17:16:29 -0800
commit1a9769dc79fdd27c347633df210ff64f48de8d07 (patch)
tree0e432d1f7a62db6f43e3f0a2f35e526b678ec7b6 /tensorflow/contrib/android/jni
parent1b960c30ab87508ca720a1ed83af3d008fd39dd3 (diff)
contrib/android: Use the Java API to implement TensorFlowInferenceInterface
This removes the native code for tensorflow_inference_jni and makes it a pure Java class. Use of the Java API allows some "configuration" values (like the number of labels in the inception example) to be determined using shape inference instead of being passed in explicitly. Change: 147398845
Diffstat (limited to 'tensorflow/contrib/android/jni')
-rw-r--r--tensorflow/contrib/android/jni/jni_utils.cc152
-rw-r--r--tensorflow/contrib/android/jni/jni_utils.h42
-rw-r--r--tensorflow/contrib/android/jni/limiting_file_input_stream.h68
-rw-r--r--tensorflow/contrib/android/jni/tensorflow_inference_jni.cc405
-rw-r--r--tensorflow/contrib/android/jni/tensorflow_inference_jni.h91
5 files changed, 0 insertions, 758 deletions
diff --git a/tensorflow/contrib/android/jni/jni_utils.cc b/tensorflow/contrib/android/jni/jni_utils.cc
deleted file mode 100644
index 71a93ea1f1..0000000000
--- a/tensorflow/contrib/android/jni/jni_utils.cc
+++ /dev/null
@@ -1,152 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/android/jni/jni_utils.h"
-
-#include <android/asset_manager.h>
-#include <android/asset_manager_jni.h>
-#include <jni.h>
-#include <stdlib.h>
-
-#include <fstream>
-#include <sstream>
-#include <string>
-#include <vector>
-
-#include "google/protobuf/io/coded_stream.h"
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-#include "google/protobuf/message_lite.h"
-#include "tensorflow/contrib/android/jni/limiting_file_input_stream.h"
-#include "tensorflow/core/platform/logging.h"
-
-static const char* const ASSET_PREFIX = "file:///android_asset/";
-
-namespace {
-class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
- public:
- explicit IfstreamInputStream(const std::string& file_name)
- : ifs_(file_name.c_str(), std::ios::in | std::ios::binary) {
- CHECK(ifs_.good()) << "Failed to open file \"" << file_name
- << "\" or file is 0 length! Use prefix \""
- << ASSET_PREFIX
- << "\" if attempting to load proto from assets.";
- }
- ~IfstreamInputStream() { ifs_.close(); }
-
- int Read(void* buffer, int size) {
- if (!ifs_) {
- return -1;
- }
- ifs_.read(static_cast<char*>(buffer), size);
- return ifs_.gcount();
- }
-
- private:
- std::ifstream ifs_;
-};
-
-} // namespace
-
-bool PortableReadFileToProto(const std::string& file_name,
- ::google::protobuf::MessageLite* proto) {
- ::google::protobuf::io::CopyingInputStreamAdaptor stream(
- new IfstreamInputStream(file_name));
-
- stream.SetOwnsCopyingStream(true);
- // TODO(jiayq): the following coded stream is for debugging purposes to allow
- // one to parse arbitrarily large messages for MessageLite. One most likely
- // doesn't want to put protobufs larger than 64MB on Android, so we should
- // eventually remove this and quit loud when a large protobuf is passed in.
- ::google::protobuf::io::CodedInputStream coded_stream(&stream);
- // Total bytes hard limit / warning limit are set to 1GB and 512MB
- // respectively.
- coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
- return proto->ParseFromCodedStream(&coded_stream);
-}
-
-bool IsAsset(const char* const filename) {
- return strstr(filename, ASSET_PREFIX) == filename;
-}
-
-void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
- const char* const filename,
- google::protobuf::MessageLite* message) {
- if (!IsAsset(filename)) {
- VLOG(0) << "Opening file: " << filename;
- CHECK(PortableReadFileToProto(filename, message));
- return;
- }
-
- CHECK_NOTNULL(asset_manager);
-
- const char* const asset_filename = filename + strlen(ASSET_PREFIX);
- AAsset* asset =
- AAssetManager_open(asset_manager, asset_filename, AASSET_MODE_STREAMING);
- CHECK_NOTNULL(asset);
-
- off_t start;
- off_t length;
- const int fd = AAsset_openFileDescriptor(asset, &start, &length);
-
- if (fd >= 0) {
- ::tensorflow::android::LimitingFileInputStream is(fd, start + length);
- google::protobuf::io::CopyingInputStreamAdaptor adaptor(&is);
- // If the file is smaller than protobuf's default limit, avoid copies.
- if (length < (64 * 1024 * 1024)) {
- // If it has a file descriptor that means it can be memmapped directly
- // from the APK.
- VLOG(0) << "Opening asset " << asset_filename
- << " from disk with zero-copy.";
- adaptor.Skip(start);
- CHECK(message->ParseFromZeroCopyStream(&adaptor));
- } else {
- ::google::protobuf::io::CodedInputStream coded_stream(&adaptor);
- // Total bytes hard limit / warning limit are set to 1GB and 512MB
- // respectively.
- coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
- coded_stream.Skip(start);
- CHECK(message->ParseFromCodedStream(&coded_stream));
- }
- } else {
- // It may be compressed, in which case we have to uncompress
- // it to memory first.
- VLOG(0) << "Opening asset " << asset_filename << " from disk with copy.";
- const off_t data_size = AAsset_getLength(asset);
-
- // TODO(andrewharp): Add codepath for loading compressed protos as well.
- if (data_size > 64 * 1024 * 1024) {
- LOG(WARNING) << "Compressed proto is larger than 64mb; if problems occur "
- << " turn off compression for protocol buffer files in APK.";
- }
-
- const void* const memory = AAsset_getBuffer(asset);
- CHECK(message->ParseFromArray(memory, data_size));
- }
- AAsset_close(asset);
-}
-
-std::string GetString(JNIEnv* env, jstring java_string) {
- const char* raw_string = env->GetStringUTFChars(java_string, 0);
- std::string return_str(raw_string);
- env->ReleaseStringUTFChars(java_string, raw_string);
- return return_str;
-}
-
-tensorflow::int64 CurrentWallTimeUs() {
- struct timeval tv;
- gettimeofday(&tv, NULL);
- return tv.tv_sec * 1000000 + tv.tv_usec;
-}
diff --git a/tensorflow/contrib/android/jni/jni_utils.h b/tensorflow/contrib/android/jni/jni_utils.h
deleted file mode 100644
index 7cef1e8396..0000000000
--- a/tensorflow/contrib/android/jni/jni_utils.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef ORG_TENSORFLOW_JNI_JNI_UTILS_H_ // NOLINT
-#define ORG_TENSORFLOW_JNI_JNI_UTILS_H_ // NOLINT
-
-#include <jni.h>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/types.h"
-
-class AAssetManager;
-
-bool PortableReadFileToProto(const std::string& file_name,
- ::google::protobuf::MessageLite* proto)
- TF_MUST_USE_RESULT;
-
-// Deserializes the contents of a file into memory.
-void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
- const char* const filename,
- google::protobuf::MessageLite* message);
-
-std::string GetString(JNIEnv* env, jstring java_string);
-
-tensorflow::int64 CurrentWallTimeUs();
-
-#endif // ORG_TENSORFLOW_JNI_JNI_UTILS_H_
diff --git a/tensorflow/contrib/android/jni/limiting_file_input_stream.h b/tensorflow/contrib/android/jni/limiting_file_input_stream.h
deleted file mode 100644
index fb3cb59719..0000000000
--- a/tensorflow/contrib/android/jni/limiting_file_input_stream.h
+++ /dev/null
@@ -1,68 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
-#define THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
-
-#include <errno.h>
-#include <unistd.h>
-#include "google/protobuf/io/zero_copy_stream_impl.h"
-#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
-
-namespace tensorflow {
-namespace android {
-
-// Input stream that reads a limited amount of data from an input file
-// descriptor.
-class LimitingFileInputStream
- : public ::google::protobuf::io::CopyingInputStream {
- public:
- // Construct a stream to read from file <fd>, returning on the first <limit>
- // bytes. If <fd> has fewer than <limit> bytes, then limit has no effect.
- LimitingFileInputStream(int fd, int limit) : fd_(fd), bytes_left_(limit) {}
- ~LimitingFileInputStream() {}
-
- int Read(void* buffer, int size) {
- int result;
- do {
- result = read(fd_, buffer, std::min(bytes_left_, size));
- } while (result < 0 && errno == EINTR);
-
- if (result < 0) {
- errno_ = errno;
- } else {
- bytes_left_ -= result;
- }
- return result;
- }
-
- int Skip(int count) {
- if (lseek(fd_, count, SEEK_CUR) == (off_t)-1) {
- return -1;
- }
- // Seek succeeded.
- bytes_left_ -= count;
- return count;
- }
-
- private:
- const int fd_;
- int bytes_left_;
- int errno_ = 0;
-};
-
-} // namespace android
-} // namespace tensorflow
-
-#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_ANDROID_JNI_LIMITING_FILE_INPUT_STREAM_H_
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc b/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
deleted file mode 100644
index d3cfe1fdf0..0000000000
--- a/tensorflow/contrib/android/jni/tensorflow_inference_jni.cc
+++ /dev/null
@@ -1,405 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/android/jni/tensorflow_inference_jni.h"
-
-#include <android/asset_manager.h>
-#include <android/asset_manager_jni.h>
-#include <android/bitmap.h>
-
-#include <jni.h>
-#include <pthread.h>
-#include <sys/stat.h>
-#include <unistd.h>
-#include <map>
-#include <queue>
-#include <sstream>
-#include <string>
-
-#include "tensorflow/contrib/android/jni/jni_utils.h"
-#include "tensorflow/core/framework/step_stats.pb.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/core/util/stat_summarizer.h"
-
-using namespace tensorflow;
-
-typedef std::map<std::string, std::pair<std::string, tensorflow::Tensor> >
- InputMap;
-
-// Variables associated with a single TF session.
-struct SessionVariables {
- std::unique_ptr<tensorflow::Session> session;
-
- int64 id = -1; // Copied from Java field for convenience.
- int num_runs = 0;
- int64 timing_total_us = 0;
-
- bool log_stats = false;
- StatSummarizer* summarizer = nullptr;
-
- InputMap input_tensors;
- std::vector<std::string> output_tensor_names;
- std::vector<tensorflow::Tensor> output_tensors;
-};
-
-static tensorflow::mutex mutex_(tensorflow::LINKER_INITIALIZED);
-
-std::map<int64, SessionVariables*>* GetSessionsSingleton() {
- static std::map<int64, SessionVariables*>* sessions PT_GUARDED_BY(mutex_) =
- new std::map<int64, SessionVariables*>();
- return sessions;
-}
-
-inline static SessionVariables* GetSessionVars(JNIEnv* env, jobject thiz) {
- jclass clazz = env->GetObjectClass(thiz);
- assert(clazz != nullptr);
- jfieldID fid = env->GetFieldID(clazz, "id", "J");
- assert(fid != nullptr);
- const int64 id = env->GetLongField(thiz, fid);
-
- // This method is thread-safe as we support working with multiple
- // sessions simultaneously. However care must be taken at the calling
- // level on a per-session basis.
- mutex_lock l(mutex_);
- std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
- if (sessions.find(id) == sessions.end()) {
- LOG(INFO) << "Creating new session variables for " << std::hex << id;
- SessionVariables* vars = new SessionVariables;
- vars->id = id;
- sessions[id] = vars;
- } else {
- VLOG(1) << "Found session variables for " << std::hex << id;
- }
- return sessions[id];
-}
-
-JNIEXPORT void JNICALL TENSORFLOW_METHOD(testLoaded)(JNIEnv* env,
- jobject thiz) {
- LOG(INFO) << "Native TF methods loaded.";
-}
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
- JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model) {
- SessionVariables* vars = GetSessionVars(env, thiz);
-
- if (vars->session.get() != nullptr) {
- LOG(INFO) << "Compute graph already loaded. skipping.";
- return 0;
- }
-
- const int64 start_time = CurrentWallTimeUs();
-
- const std::string model_str = GetString(env, model);
-
- LOG(INFO) << "Loading Tensorflow.";
-
- tensorflow::SessionOptions options;
- tensorflow::ConfigProto& config = options.config;
-
- tensorflow::Session* session = tensorflow::NewSession(options);
- vars->session.reset(session);
- LOG(INFO) << "Session created.";
-
- tensorflow::GraphDef tensorflow_graph;
-
- AAssetManager* const asset_manager =
- AAssetManager_fromJava(env, java_asset_manager);
- LOG(INFO) << "Acquired AssetManager.";
-
- LOG(INFO) << "Reading file to proto: " << model_str;
- ReadFileToProtoOrDie(asset_manager, model_str.c_str(), &tensorflow_graph);
- CHECK(tensorflow_graph.node_size() > 0) << "Problem loading GraphDef!";
-
- LOG(INFO) << "GraphDef loaded from " << model_str << " with "
- << tensorflow_graph.node_size() << " nodes.";
-
- // Whether or not stat logging is currently enabled, the StatSummarizer must
- // be initialized here with the GraphDef while it is available.
- vars->summarizer = new StatSummarizer(tensorflow_graph);
-
- LOG(INFO) << "Creating TensorFlow graph from GraphDef.";
- tensorflow::Status s = session->Create(tensorflow_graph);
-
- // Clear the proto to save memory space.
- tensorflow_graph.Clear();
-
- if (!s.ok()) {
- LOG(ERROR) << "Could not create TensorFlow graph: " << s;
- return s.code();
- }
-
- const int64 end_time = CurrentWallTimeUs();
- LOG(INFO) << "Initialization done in " << (end_time - start_time) / 1000.0
- << "ms";
-
- return s.code();
-}
-
-static tensorflow::Tensor* GetTensor(JNIEnv* env, jobject thiz,
- jstring node_name_jstring) {
- SessionVariables* vars = GetSessionVars(env, thiz);
- std::string node_name = GetString(env, node_name_jstring);
-
- int output_index = -1;
- for (int i = 0; i < vars->output_tensors.size(); ++i) {
- if (vars->output_tensor_names[i] == node_name) {
- output_index = i;
- break;
- }
- }
- if (output_index == -1) {
- LOG(ERROR) << "Output [" << node_name << "] not found, aborting!";
- return nullptr;
- }
-
- tensorflow::Tensor* output = &vars->output_tensors[output_index];
- return output;
-}
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
- JNIEnv* env, jobject thiz, jobjectArray output_name_strings) {
- SessionVariables* vars = GetSessionVars(env, thiz);
-
- // Add the requested outputs to the output list.
- vars->output_tensor_names.clear();
- for (int i = 0; i < env->GetArrayLength(output_name_strings); i++) {
- jstring java_string =
- (jstring)(env->GetObjectArrayElement(output_name_strings, i));
- std::string output_name = GetString(env, java_string);
- vars->output_tensor_names.push_back(output_name);
- }
-
- ++(vars->num_runs);
- tensorflow::Status s;
- int64 start_time, end_time;
-
- start_time = CurrentWallTimeUs();
-
- std::vector<std::pair<std::string, tensorflow::Tensor> > input_tensors;
- for (const auto& entry : vars->input_tensors) {
- input_tensors.push_back(entry.second);
- }
-
- vars->output_tensors.clear();
-
- if (vars->log_stats) {
- RunOptions run_options;
- run_options.set_trace_level(RunOptions::FULL_TRACE);
- RunMetadata run_metadata;
-
- s = vars->session->Run(run_options, input_tensors,
- vars->output_tensor_names, {},
- &(vars->output_tensors), &run_metadata);
-
- assert(run_metadata.has_step_stats());
- const StepStats& step_stats = run_metadata.step_stats();
- vars->summarizer->ProcessStepStats(step_stats);
-
- // Print the full output string, not just the abbreviated one returned by
- // getStatString().
- vars->summarizer->PrintStepStats();
- } else {
- s = vars->session->Run(input_tensors, vars->output_tensor_names, {},
- &(vars->output_tensors));
- }
-
- end_time = CurrentWallTimeUs();
- const int64 elapsed_time_inf = end_time - start_time;
- vars->timing_total_us += elapsed_time_inf;
- VLOG(0) << "End computing. Ran in " << elapsed_time_inf / 1000 << "ms ("
- << (vars->timing_total_us / vars->num_runs / 1000) << "ms avg over "
- << vars->num_runs << " runs)";
-
- if (!s.ok()) {
- LOG(ERROR) << "Error during inference: " << s;
- }
- return s.code();
-}
-
-JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
- JNIEnv* env, jobject thiz, jboolean enableStatLogging) {
- SessionVariables* vars = GetSessionVars(env, thiz);
- vars->log_stats = enableStatLogging;
-}
-
-JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
- jobject thiz) {
- // Return an abbreviated stat string suitable for displaying on screen.
- SessionVariables* vars = GetSessionVars(env, thiz);
- std::stringstream ss;
- ss << vars->summarizer->GetStatsByMetric("Top 10 CPU",
- StatSummarizer::BY_TIME, 10);
- ss << vars->summarizer->GetStatsByNodeType();
- ss << vars->summarizer->ShortSummary();
- return env->NewStringUTF(ss.str().c_str());
-}
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz) {
- SessionVariables* vars = GetSessionVars(env, thiz);
-
- tensorflow::Status s = vars->session->Close();
- if (!s.ok()) {
- LOG(ERROR) << "Error closing session: " << s;
- }
-
- delete vars->summarizer;
-
- mutex_lock l(mutex_);
- std::map<int64, SessionVariables*>& sessions = *GetSessionsSingleton();
- sessions.erase(vars->id);
- delete vars;
-
- return s.code();
-}
-
-// TODO(andrewharp): Use memcpy to fill/read nodes.
-#define FILL_NODE_METHOD(DTYPE, JAVA_DTYPE, CTYPE, TENSOR_DTYPE) \
- FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
- SessionVariables* vars = GetSessionVars(env, thiz); \
- jboolean iCopied = JNI_FALSE; \
- tensorflow::TensorShape shape; \
- jint* dim_vals = env->GetIntArrayElements(dims, &iCopied); \
- const int num_dims = env->GetArrayLength(dims); \
- for (int i = 0; i < num_dims; ++i) { \
- shape.AddDim(dim_vals[i]); \
- } \
- env->ReleaseIntArrayElements(dims, dim_vals, JNI_ABORT); \
- tensorflow::Tensor input_tensor(TENSOR_DTYPE, shape); \
- auto tensor_mapped = input_tensor.flat<CTYPE>(); \
- j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(src, &iCopied); \
- j##JAVA_DTYPE* value_ptr = values; \
- const int src_size = static_cast<int>(env->GetArrayLength(src)); \
- const int dst_size = static_cast<int>(tensor_mapped.size()); \
- CHECK_GE(src_size, dst_size) \
- << "src array must have at least as many elements as dst Tensor."; \
- const int num_items = std::min(src_size, dst_size); \
- for (int i = 0; i < num_items; ++i) { \
- tensor_mapped(i) = *value_ptr++; \
- } \
- env->Release##DTYPE##ArrayElements(src, values, JNI_ABORT); \
- std::string input_name = GetString(env, node_name); \
- std::pair<std::string, tensorflow::Tensor> input_pair(input_name, \
- input_tensor); \
- vars->input_tensors[input_name] = input_pair; \
- }
-
-#define FILL_NODE_NIO_BUFFER_METHOD(DTYPE, CTYPE, TENSOR_DTYPE) \
- FILL_NODE_NIO_BUFFER_SIGNATURE(DTYPE) { \
- SessionVariables* vars = GetSessionVars(env, thiz); \
- tensorflow::TensorShape shape; \
- const int* dim_vals = reinterpret_cast<const int*>( \
- env->GetDirectBufferAddress(dims_buffer)); \
- const int num_dims = env->GetDirectBufferCapacity(dims_buffer); \
- for (int i = 0; i < num_dims; ++i) { \
- shape.AddDim(dim_vals[i]); \
- } \
- tensorflow::Tensor input_tensor(TENSOR_DTYPE, shape); \
- auto tensor_mapped = input_tensor.flat<CTYPE>(); \
- const CTYPE* values = reinterpret_cast<const CTYPE*>( \
- env->GetDirectBufferAddress(src_buffer)); \
- const CTYPE* value_ptr = values; \
- const int src_size = \
- static_cast<int>(env->GetDirectBufferCapacity(src_buffer)); \
- const int dst_size = static_cast<int>(tensor_mapped.size()); \
- CHECK_GE(src_size, dst_size) \
- << "src buffer must have at least as many elements as dst Tensor."; \
- const int num_items = std::min(src_size, dst_size); \
- for (int i = 0; i < num_items; ++i) { \
- tensor_mapped(i) = *value_ptr++; \
- } \
- std::string input_name = GetString(env, node_name); \
- std::pair<std::string, tensorflow::Tensor> input_pair(input_name, \
- input_tensor); \
- vars->input_tensors[input_name] = input_pair; \
- }
-
-#define READ_NODE_METHOD(DTYPE, JAVA_DTYPE, CTYPE) \
- READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) { \
- SessionVariables* vars = GetSessionVars(env, thiz); \
- Tensor* t = GetTensor(env, thiz, node_name); \
- if (t == nullptr) { \
- return -1; \
- } \
- auto tensor_mapped = t->flat<CTYPE>(); \
- jboolean iCopied = JNI_FALSE; \
- j##JAVA_DTYPE* values = env->Get##DTYPE##ArrayElements(dst, &iCopied); \
- if (values == nullptr) { \
- return -1; \
- } \
- j##JAVA_DTYPE* value_ptr = values; \
- const int src_size = static_cast<int>(tensor_mapped.size()); \
- const int dst_size = static_cast<int>(env->GetArrayLength(dst)); \
- CHECK_GE(dst_size, src_size) \
- << "dst array must have length >= src Tensor's flattened size."; \
- const int num_items = std::min(src_size, dst_size); \
- for (int i = 0; i < num_items; ++i) { \
- *value_ptr++ = tensor_mapped(i); \
- } \
- env->Release##DTYPE##ArrayElements(dst, values, 0); \
- return 0; \
- }
-
-#define READ_NODE_NIO_BUFFER_METHOD(DTYPE, CTYPE) \
- READ_NODE_NIO_BUFFER_SIGNATURE(DTYPE) { \
- SessionVariables* vars = GetSessionVars(env, thiz); \
- Tensor* t = GetTensor(env, thiz, node_name); \
- if (t == nullptr) { \
- return -1; \
- } \
- auto tensor_mapped = t->flat<CTYPE>(); \
- CTYPE* values = \
- reinterpret_cast<CTYPE*>(env->GetDirectBufferAddress(dst_buffer)); \
- if (values == nullptr) { \
- return -1; \
- } \
- CTYPE* value_ptr = values; \
- const int src_size = static_cast<int>(tensor_mapped.size()); \
- const int dst_size = \
- static_cast<int>(env->GetDirectBufferCapacity(dst_buffer)); \
- CHECK_GE(dst_size, src_size) \
- << "dst buffer must have capacity >= src Tensor's flattened size."; \
- const int num_items = std::min(src_size, dst_size); \
- for (int i = 0; i < num_items; ++i) { \
- *value_ptr++ = tensor_mapped(i); \
- } \
- return 0; \
- }
-
-FILL_NODE_METHOD(Float, float, float, tensorflow::DT_FLOAT)
-FILL_NODE_METHOD(Int, int, int, tensorflow::DT_INT32)
-FILL_NODE_METHOD(Double, double, double, tensorflow::DT_DOUBLE)
-FILL_NODE_METHOD(Byte, byte, uint8_t, tensorflow::DT_UINT8)
-
-FILL_NODE_NIO_BUFFER_METHOD(Float, float, tensorflow::DT_FLOAT)
-FILL_NODE_NIO_BUFFER_METHOD(Int, int, tensorflow::DT_INT32)
-FILL_NODE_NIO_BUFFER_METHOD(Double, double, tensorflow::DT_DOUBLE)
-FILL_NODE_NIO_BUFFER_METHOD(Byte, uint8_t, tensorflow::DT_UINT8)
-
-READ_NODE_METHOD(Float, float, float)
-READ_NODE_METHOD(Int, int, int)
-READ_NODE_METHOD(Double, double, double)
-READ_NODE_METHOD(Byte, byte, uint8_t)
-
-READ_NODE_NIO_BUFFER_METHOD(Float, float);
-READ_NODE_NIO_BUFFER_METHOD(Int, int);
-READ_NODE_NIO_BUFFER_METHOD(Double, double);
-READ_NODE_NIO_BUFFER_METHOD(Byte, uint8_t);
diff --git a/tensorflow/contrib/android/jni/tensorflow_inference_jni.h b/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
deleted file mode 100644
index 93fb8ba315..0000000000
--- a/tensorflow/contrib/android/jni/tensorflow_inference_jni.h
+++ /dev/null
@@ -1,91 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// The methods are exposed to Java to allow for interaction with the native
-// TensorFlow code. See
-// tensorflow/examples/android/src/org/tensorflow/TensorFlowClassifier.java
-// for the Java counterparts.
-
-#ifndef ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT
-#define ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT
-
-#include <jni.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-#define TENSORFLOW_METHOD(METHOD_NAME) \
- Java_org_tensorflow_contrib_android_TensorFlowInferenceInterface_##METHOD_NAME // NOLINT
-
-#define FILL_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
- JNIEXPORT void TENSORFLOW_METHOD(fillNode##DTYPE)( \
- JNIEnv * env, jobject thiz, jstring node_name, jintArray dims, \
- j##JAVA_DTYPE##Array src)
-
-#define FILL_NODE_NIO_BUFFER_SIGNATURE(DTYPE) \
- JNIEXPORT void TENSORFLOW_METHOD(fillNodeFrom##DTYPE##Buffer)( \
- JNIEnv * env, jobject thiz, jstring node_name, jobject dims_buffer, \
- jobject src_buffer)
-
-#define READ_NODE_SIGNATURE(DTYPE, JAVA_DTYPE) \
- JNIEXPORT jint TENSORFLOW_METHOD(readNode##DTYPE)( \
- JNIEnv * env, jobject thiz, jstring node_name, j##JAVA_DTYPE##Array dst)
-
-#define READ_NODE_NIO_BUFFER_SIGNATURE(DTYPE) \
- JNIEXPORT jint TENSORFLOW_METHOD(readNodeInto##DTYPE##Buffer)( \
- JNIEnv * env, jobject thiz, jstring node_name, jobject dst_buffer)
-
-JNIEXPORT void JNICALL TENSORFLOW_METHOD(testLoaded)(JNIEnv* env, jobject thiz);
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(initializeTensorFlow)(
- JNIEnv* env, jobject thiz, jobject java_asset_manager, jstring model);
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(runInference)(
- JNIEnv* env, jobject thiz, jobjectArray output_name_strings);
-
-JNIEXPORT void JNICALL TENSORFLOW_METHOD(enableStatLogging)(
- JNIEnv* env, jobject thiz, jboolean enableStatLogging);
-
-JNIEXPORT jstring JNICALL TENSORFLOW_METHOD(getStatString)(JNIEnv* env,
- jobject thiz);
-
-JNIEXPORT jint JNICALL TENSORFLOW_METHOD(close)(JNIEnv* env, jobject thiz);
-
-FILL_NODE_SIGNATURE(Float, float);
-FILL_NODE_SIGNATURE(Int, int);
-FILL_NODE_SIGNATURE(Double, double);
-FILL_NODE_SIGNATURE(Byte, byte);
-
-FILL_NODE_NIO_BUFFER_SIGNATURE(Float);
-FILL_NODE_NIO_BUFFER_SIGNATURE(Int);
-FILL_NODE_NIO_BUFFER_SIGNATURE(Double);
-FILL_NODE_NIO_BUFFER_SIGNATURE(Byte);
-
-READ_NODE_SIGNATURE(Float, float);
-READ_NODE_SIGNATURE(Int, int);
-READ_NODE_SIGNATURE(Double, double);
-READ_NODE_SIGNATURE(Byte, byte);
-
-READ_NODE_NIO_BUFFER_SIGNATURE(Float);
-READ_NODE_NIO_BUFFER_SIGNATURE(Int);
-READ_NODE_NIO_BUFFER_SIGNATURE(Double);
-READ_NODE_NIO_BUFFER_SIGNATURE(Byte);
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
-
-#endif // ORG_TENSORFLOW_JNI_TENSORFLOW_JNI_H_ // NOLINT