diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/DataType.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/DataType.java | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 7b92be6d38..ded09974a4 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -17,21 +17,22 @@ package org.tensorflow; import java.util.HashMap; import java.util.Map; + import org.tensorflow.types.UInt8; /** Represents the type of elements in a {@link Tensor} as an enum. */ public enum DataType { /** 32-bit single precision floating point. */ - FLOAT(1), + FLOAT(1, 4), /** 64-bit double precision floating point. */ - DOUBLE(2), + DOUBLE(2, 8), /** 32-bit signed integer. */ - INT32(3), + INT32(3, 4), /** 8-bit unsigned integer. */ - UINT8(4), + UINT8(4, 1), /** * A sequence of bytes. @@ -41,16 +42,36 @@ public enum DataType { STRING(7), /** 64-bit signed integer. */ - INT64(9), + INT64(9, 8), /** Boolean. */ - BOOL(10); + BOOL(10, 1); private final int value; + + private final int sizeInBytes; - // The integer value must match the corresponding TF_* value in the TensorFlow C API. + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + */ DataType(int value) { + this(value, -1); + } + + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown + */ + DataType(int value, int sizeInBytes) { this.value = value; + this.sizeInBytes = sizeInBytes; + } + + /** + * @return size of an element of this type, in bytes, or -1 if element size is variable + */ + public int sizeInBytes() { + return sizeInBytes; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ |