diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc')
-rw-r--r-- | tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc | 446 |
1 files changed, 446 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc new file mode 100644 index 0000000000..bc6462eb54 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -0,0 +1,446 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" + +namespace { + +const int kByteBufferValue = 999; +const int kBufferSize = 256; + +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to Interpreter."); + return nullptr; + } + return reinterpret_cast<tflite::Interpreter*>(handle); +} + +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, "Invalid handle to model."); + return nullptr; + } + return reinterpret_cast<tflite::FlatBufferModel*>(handle); +} + +BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to ErrorReporter."); + return nullptr; + } + return reinterpret_cast<BufferErrorReporter*>(handle); +} + +std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { + int size = static_cast<int>(env->GetArrayLength(inputs)); + std::vector<int> outputs(size, 0); + jint* ptr = env->GetIntArrayElements(inputs, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return {}; + } + for (int i = 0; i < size; ++i) { + outputs[i] = ptr[i]; + } + env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); + return outputs; +} + +bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } + +TfLiteType resolveDataType(jint data_type) { + switch (data_type) { + case 1: + return kTfLiteFloat32; + case 2: + return kTfLiteInt32; + case 3: + return kTfLiteUInt8; + case 4: + return kTfLiteInt64; + default: + return kTfLiteNoType; + } +} + +void printDims(char* buffer, int max_size, int* dims, int num_dims) { + if (max_size <= 0) return; + buffer[0] = '?'; + int size = 1; + for (int i = 1; i < num_dims; ++i) { + if (max_size > size) { + int written_size = + snprintf(buffer + size, max_size - size, ",%d", dims[i]); + if (written_size < 0) return; + size += written_size; + } + } +} + +TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, + const int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values, + jobjectArray sizes) { + if (input_size != interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Expected num of inputs is %d but got %d", + interpreter->inputs().size(), input_size); + return kTfLiteError; + } + if (input_size != env->GetArrayLength(data_types) || + input_size != env->GetArrayLength(nums_of_bytes) || + input_size != env->GetArrayLength(values)) { + throwException(env, kIllegalArgumentException, + "Arrays in arguments should be of the same length, but got " + "%d sizes, %d data_types, %d nums_of_bytes, and %d values", + input_size, env->GetArrayLength(data_types), + env->GetArrayLength(nums_of_bytes), + env->GetArrayLength(values)); + return kTfLiteError; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jintArray dims = + static_cast<jintArray>(env->GetObjectArrayElement(sizes, i)); + int num_dims = static_cast<int>(env->GetArrayLength(dims)); + if (target->dims->size != num_dims) { + throwException(env, kIllegalArgumentException, + "%d-th input should have %d dimensions, but found %d " + "dimensions", + i, target->dims->size, num_dims); + return kTfLiteError; + } + jint* ptr = env->GetIntArrayElements(dims, nullptr); + for (int j = 1; j < num_dims; ++j) { + if (target->dims->data[j] != ptr[j]) { + std::unique_ptr<char[]> expected_dims(new char[kBufferSize]); + std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]); + printDims(expected_dims.get(), kBufferSize, target->dims->data, + num_dims); + printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); + throwException(env, kIllegalArgumentException, + "%d-th input dimension should be [%s], but found [%s]", + i, expected_dims.get(), obtained_dims.get()); + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return kTfLiteError; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast<jintArray>(env->GetObjectArrayElement(sizes, i)); + TfLiteStatus status = interpreter->ResizeInputTensor( + input_idx, convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + return status; + } + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values) { + jint* data_type = env->GetIntArrayElements(data_types, nullptr); + jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jobject value = env->GetObjectArrayElement(values, i); + bool is_byte_buffer = isByteBuffer(data_type[i]); + if (is_byte_buffer) { + writeByteBuffer(env, value, &(target->data.raw), + static_cast<int>(num_bytes[i])); + } else { + TfLiteType type = resolveDataType(data_type[i]); + if (type != target->type) { + throwException(env, kIllegalArgumentException, + "DataType (%d) of input data does not match with the " + "DataType (%d) of model inputs.", + type, target->type); + return kTfLiteError; + } + writeMultiDimensionalArray(env, value, target->type, target->dims->size, + &(target->data.raw), + static_cast<int>(num_bytes[i])); + } + env->DeleteLocalRef(value); + if (env->ExceptionCheck()) return kTfLiteError; + } + env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); + env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); + return kTfLiteOk; +} + +} // namespace + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get input names."); + return nullptr; + } + size_t size = interpreter->inputs().size(); + jobjectArray names = static_cast<jobjectArray>( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement(names, i, + env->NewStringUTF(interpreter->GetInputName(i))); + } + return names; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get output names."); + return nullptr; + } + size_t size = interpreter->outputs().size(); + jobjectArray names = static_cast<jobjectArray>( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement( + names, i, env->NewStringUTF(interpreter->GetOutputName(i))); + } + return names; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->UseNNAPI(static_cast<bool>(state)); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size) { + BufferErrorReporter* error_reporter = + new BufferErrorReporter(env, static_cast<int>(size)); + return reinterpret_cast<jlong>(error_reporter); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* path = env->GetStringUTFChars(model_file, nullptr); + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "Contents of %s does not encode a valid TensorFlowLite " + "model: %s", + path, error_reporter->CachedErrorMessage()); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + env->ReleaseStringUTFChars(model_file, path); + return reinterpret_cast<jlong>(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* buf = + static_cast<char*>(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + auto model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast<size_t>(capacity), error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer does not encode a valid TensorFlowLite " + "model: %s", + error_reporter->CachedErrorMessage()); + return 0; + } + return reinterpret_cast<jlong>(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle) { + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); + if (model == nullptr) return 0; + auto resolver = ::tflite::CreateOpResolver(); + std::unique_ptr<tflite::Interpreter> interpreter; + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + return reinterpret_cast<jlong>(interpreter.release()); +} + +// Sets inputs, runs inference, and returns outputs as long handles. +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values) { + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return nullptr; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return nullptr; + const int input_size = env->GetArrayLength(sizes); + // validates inputs + TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, + nums_of_bytes, values, sizes); + if (status != kTfLiteOk) return nullptr; + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // sets inputs + status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, + values); + if (status != kTfLiteOk) return nullptr; + // runs inference + if (interpreter->Invoke() != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to run on the given Interpreter: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // returns outputs + const std::vector<int>& results = interpreter->outputs(); + if (results.empty()) { + throwException(env, kIllegalArgumentException, + "The Interpreter does not have any outputs."); + return nullptr; + } + jlongArray outputs = env->NewLongArray(results.size()); + size_t size = results.size(); + for (int i = 0; i < size; ++i) { + TfLiteTensor* source = interpreter->tensor(results[i]); + jlong output = reinterpret_cast<jlong>(source); + env->SetLongArrayRegion(outputs, i, 1, &output); + } + return outputs; +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + const int idx = static_cast<int>(input_idx); + if (input_idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th input out of %d inputs", + input_idx, interpreter->inputs().size()); + return nullptr; + } + TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); + int size = target->dims->size; + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. %d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } + jintArray outputs = env->NewIntArray(size); + env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); + return outputs; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return; + const int idx = static_cast<int>(input_idx); + if (idx < 0 || idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Can not resize %d-th input for a model having %d inputs.", + idx, interpreter->inputs().size()); + } + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle) { + if (interpreter_handle != 0) { + delete convertLongToInterpreter(env, interpreter_handle); + } + if (model_handle != 0) { + delete convertLongToModel(env, model_handle); + } + if (error_handle != 0) { + delete convertLongToErrorReporter(env, error_handle); + } +} |