diff options
Diffstat (limited to 'tensorflow/java/src/main/native/tensor_jni.cc')
-rw-r--r-- | tensorflow/java/src/main/native/tensor_jni.cc | 6 |
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc index 7bfe6c896d..745abec244 100644 --- a/tensorflow/java/src/main/native/tensor_jni.cc +++ b/tensorflow/java/src/main/native/tensor_jni.cc @@ -41,8 +41,11 @@ size_t elemByteSize(TF_DataType dtype) { // have the same byte sizes. Validate that: switch (dtype) { case TF_BOOL: + case TF_UINT8: static_assert(sizeof(jboolean) == 1, "Java boolean not compatible with TF_BOOL"); + static_assert(sizeof(jbyte) == 1, + "Java byte not compatible with TF_UINT8"); return 1; case TF_FLOAT: case TF_INT32: @@ -90,6 +93,7 @@ void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst, CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double); CASE(TF_INT32, jint, "intValue", "()I", Int); CASE(TF_INT64, jlong, "longValue", "()J", Long); + CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte); #undef CASE case TF_BOOL: { jclass clazz = env->FindClass("java/lang/Boolean"); @@ -134,6 +138,7 @@ size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst, CASE(TF_INT32, jint, Int); CASE(TF_INT64, jlong, Long); CASE(TF_BOOL, jboolean, Boolean); + CASE(TF_UINT8, jbyte, Byte); #undef CASE default: throwException(env, kIllegalStateException, "invalid DataType(%d)", @@ -168,6 +173,7 @@ size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src, CASE(TF_INT32, jint, Int); CASE(TF_INT64, jlong, Long); CASE(TF_BOOL, jboolean, Boolean); + CASE(TF_UINT8, jbyte, Byte); #undef CASE default: throwException(env, kIllegalStateException, "invalid DataType(%d)", |