diff options
Diffstat (limited to 'tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc')
-rw-r--r-- | tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc | 131 |
1 files changed, 89 insertions, 42 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 9e9387da86..7ff96a3172 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { return reinterpret_cast<TfLiteTensor*>(handle); } +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Interal error: Java float not compatible with " + "kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Interal error: Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Interal error: Java byte not compatible with " + "kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Interal error: Java long not compatible with " + "kTfLiteInt64"); + return 8; + default: + return 0; + } +} + size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, void* dst, size_t dst_size) { jarray array = static_cast<jarray>(object); @@ -141,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, } } -} // namespace - -size_t elementByteSize(TfLiteType data_type) { - // The code in this file makes the assumption that the - // TensorFlow TF_DataTypes and the Java primitive types - // have the same byte sizes. Validate that: - switch (data_type) { - case kTfLiteFloat32: - static_assert(sizeof(jfloat) == 4, - "Interal error: Java float not compatible with " - "kTfLiteFloat"); - return 4; - case kTfLiteInt32: - static_assert(sizeof(jint) == 4, - "Interal error: Java int not compatible with kTfLiteInt"); - return 4; - case kTfLiteUInt8: - static_assert(sizeof(jbyte) == 1, - "Interal error: Java byte not compatible with " - "kTfLiteUInt8"); - return 1; - case kTfLiteInt64: - static_assert(sizeof(jlong) == 8, - "Interal error: Java long not compatible with " - "kTfLiteInt64"); - return 8; - default: - return 0; - } -} - -size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { - char* buf = static_cast<char*>(env->GetDirectBufferAddress(object)); - if (!buf) { - throwException(env, kIllegalArgumentException, - "Input ByteBuffer is not a direct buffer"); - return 0; - } - *dst = buf; - return dst_size; -} - size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, int dims_left, char** dst, int dst_size) { if (dims_left <= 1) { @@ -203,6 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, } } +} // namespace + +JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Tensor hasn't been allocated."); + return nullptr; + } + return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw), + static_cast<jlong>(tensor->bytes)); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer( + JNIEnv* env, jclass clazz, jlong handle, jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + + char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src)); + if (!src_data_raw) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return; + } + + tensor->data.raw = src_data_raw; +} + JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, jclass clazz, @@ -220,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, num_dims, static_cast<jarray>(value)); } +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject src) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + if (tensor->data.raw == nullptr) { + throwException(env, kIllegalArgumentException, + "Internal error: Target Tensor hasn't been allocated."); + return; + } + if (tensor->dims->size == 0) { + throwException(env, kIllegalArgumentException, + "Internal error: Cannot copy empty/scalar Tensors."); + return; + } + writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size, + &tensor->data.raw, tensor->bytes); +} + JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, jclass clazz, jlong handle) { @@ -237,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data); return result; } + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env, + jclass clazz, + jlong handle) { + const TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast<jint>(tensor->bytes); +} |