aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/native/tensor_jni.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/native/tensor_jni.cc')
-rw-r--r--tensorflow/java/src/main/native/tensor_jni.cc6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc
index 7bfe6c896d..745abec244 100644
--- a/tensorflow/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/java/src/main/native/tensor_jni.cc
@@ -41,8 +41,11 @@ size_t elemByteSize(TF_DataType dtype) {
// have the same byte sizes. Validate that:
switch (dtype) {
case TF_BOOL:
+ case TF_UINT8:
static_assert(sizeof(jboolean) == 1,
"Java boolean not compatible with TF_BOOL");
+ static_assert(sizeof(jbyte) == 1,
+ "Java byte not compatible with TF_UINT8");
return 1;
case TF_FLOAT:
case TF_INT32:
@@ -90,6 +93,7 @@ void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
CASE(TF_INT32, jint, "intValue", "()I", Int);
CASE(TF_INT64, jlong, "longValue", "()J", Long);
+ CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
#undef CASE
case TF_BOOL: {
jclass clazz = env->FindClass("java/lang/Boolean");
@@ -134,6 +138,7 @@ size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
CASE(TF_INT32, jint, Int);
CASE(TF_INT64, jlong, Long);
CASE(TF_BOOL, jboolean, Boolean);
+ CASE(TF_UINT8, jbyte, Byte);
#undef CASE
default:
throwException(env, kIllegalStateException, "invalid DataType(%d)",
@@ -168,6 +173,7 @@ size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
CASE(TF_INT32, jint, Int);
CASE(TF_INT64, jlong, Long);
CASE(TF_BOOL, jboolean, Boolean);
+ CASE(TF_UINT8, jbyte, Byte);
#undef CASE
default:
throwException(env, kIllegalStateException, "invalid DataType(%d)",