aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/DataType.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/DataType.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java35
1 files changed, 28 insertions, 7 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index 7b92be6d38..ded09974a4 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -17,21 +17,22 @@ package org.tensorflow;
import java.util.HashMap;
import java.util.Map;
+
import org.tensorflow.types.UInt8;
/** Represents the type of elements in a {@link Tensor} as an enum. */
public enum DataType {
/** 32-bit single precision floating point. */
- FLOAT(1),
+ FLOAT(1, 4),
/** 64-bit double precision floating point. */
- DOUBLE(2),
+ DOUBLE(2, 8),
/** 32-bit signed integer. */
- INT32(3),
+ INT32(3, 4),
/** 8-bit unsigned integer. */
- UINT8(4),
+ UINT8(4, 1),
/**
* A sequence of bytes.
@@ -41,16 +42,36 @@ public enum DataType {
STRING(7),
/** 64-bit signed integer. */
- INT64(9),
+ INT64(9, 8),
/** Boolean. */
- BOOL(10);
+ BOOL(10, 1);
private final int value;
+
+ private final int sizeInBytes;
- // The integer value must match the corresponding TF_* value in the TensorFlow C API.
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ */
DataType(int value) {
+ this(value, -1);
+ }
+
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown
+ */
+ DataType(int value, int sizeInBytes) {
this.value = value;
+ this.sizeInBytes = sizeInBytes;
+ }
+
+ /**
+ * @return size of an element of this type, in bytes, or -1 if element size is variable
+ */
+ public int sizeInBytes() {
+ return sizeInBytes;
}
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */