aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Tensor.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java122
1 files changed, 79 insertions, 43 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 4424100390..c5ad1ee51c 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -25,6 +25,7 @@ import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
+import java.util.HashMap;
/**
* A typed multi-dimensional array.
@@ -97,9 +98,19 @@ public final class Tensor implements AutoCloseable {
* using {@link #create(DataType, long[], ByteBuffer)} instead.
*/
public static Tensor create(Object obj) {
+ return create(obj, dataTypeOf(obj));
+ }
+
+ /**
+ * Create a Tensor of data type {@code dtype} from a Java object.
+ *
+ * @param dtype the intended tensor data type. It must match the the run-time type of the object.
+ */
+ static Tensor create(Object obj, DataType dtype) {
Tensor t = new Tensor();
- t.dtype = dataTypeOf(obj);
- t.shapeCopy = new long[numDimensions(obj)];
+ t.dtype = dtype;
+ t.shapeCopy = new long[numDimensions(obj, dtype)];
+ assert objectCompatWithType(obj, dtype);
fillShape(obj, 0, t.shapeCopy);
if (t.dtype != DataType.STRING) {
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
@@ -190,8 +201,7 @@ public final class Tensor implements AutoCloseable {
*
* <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
* encoded into {@code data} as per the specification of the TensorFlow <a
- * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C
- * API</a>.
+ * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
*
* @param dataType the tensor datatype.
* @param shape the tensor shape.
@@ -537,56 +547,70 @@ public final class Tensor implements AutoCloseable {
}
}
+ private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();
+
+ static {
+ classDataTypes.put(int.class, DataType.INT32);
+ classDataTypes.put(Integer.class, DataType.INT32);
+ classDataTypes.put(long.class, DataType.INT64);
+ classDataTypes.put(Long.class, DataType.INT64);
+ classDataTypes.put(float.class, DataType.FLOAT);
+ classDataTypes.put(Float.class, DataType.FLOAT);
+ classDataTypes.put(double.class, DataType.DOUBLE);
+ classDataTypes.put(Double.class, DataType.DOUBLE);
+ classDataTypes.put(byte.class, DataType.STRING);
+ classDataTypes.put(Byte.class, DataType.STRING);
+ classDataTypes.put(boolean.class, DataType.BOOL);
+ classDataTypes.put(Boolean.class, DataType.BOOL);
+ }
+
private static DataType dataTypeOf(Object o) {
- if (o.getClass().isArray()) {
- if (Array.getLength(o) == 0) {
- throw new IllegalArgumentException("cannot create Tensors with a 0 dimension");
- }
- // byte[] is a DataType.STRING scalar.
- Object e = Array.get(o, 0);
- if (e == null) {
- throwExceptionIfNotByteOfByteArrays(o);
- return DataType.STRING;
- }
- if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
- return DataType.STRING;
- }
- return dataTypeOf(e);
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
}
- if (Float.class.isInstance(o) || float.class.isInstance(o)) {
- return DataType.FLOAT;
- } else if (Double.class.isInstance(o) || double.class.isInstance(o)) {
- return DataType.DOUBLE;
- } else if (Integer.class.isInstance(o) || int.class.isInstance(o)) {
- return DataType.INT32;
- } else if (Long.class.isInstance(o) || long.class.isInstance(o)) {
- return DataType.INT64;
- } else if (Boolean.class.isInstance(o) || boolean.class.isInstance(o)) {
- return DataType.BOOL;
- } else {
- throw new IllegalArgumentException("cannot create Tensors of " + o.getClass().getName());
+ DataType ret = classDataTypes.get(c);
+ if (ret != null) {
+ return ret;
}
+ throw new IllegalArgumentException("cannot create Tensors of type " + c.getName());
}
- private static int numDimensions(Object o) {
- if (o.getClass().isArray()) {
- Object e = Array.get(o, 0);
- if (e == null) {
- throwExceptionIfNotByteOfByteArrays(o);
- return 1;
- } else if (Byte.class.isInstance(e) || byte.class.isInstance(e)) {
- return 0;
- }
- return 1 + numDimensions(e);
+ /**
+ * Returns the number of dimensions of a tensor of type dtype when represented by the object o.
+ */
+ private static int numDimensions(Object o, DataType dtype) {
+ int ret = numArrayDimensions(o);
+ if (dtype == DataType.STRING && ret > 0) {
+ return ret - 1;
}
- return 0;
+ return ret;
}
+ /** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */
+ private static int numArrayDimensions(Object o) {
+ Class<?> c = o.getClass();
+ int i = 0;
+ while (c.isArray()) {
+ c = c.getComponentType();
+ i++;
+ }
+ return i;
+ }
+
+ /**
+ * Fills in the remaining entries in the shape array starting from position {@code dim} with the
+ * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have
+ * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise.
+ */
private static void fillShape(Object o, int dim, long[] shape) {
if (shape == null || dim == shape.length) {
return;
}
final int len = Array.getLength(o);
+ if (len == 0) {
+ throw new IllegalArgumentException("cannot create Tensors with a 0 dimension");
+ }
if (shape[dim] == 0) {
shape[dim] = len;
} else if (shape[dim] != len) {
@@ -598,15 +622,27 @@ public final class Tensor implements AutoCloseable {
}
}
+ /** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */
+ private static boolean objectCompatWithType(Object obj, DataType dtype) {
+ DataType dto = dataTypeOf(obj);
+ if (dto.equals(dtype)) {
+ return true;
+ }
+ if (dto == DataType.STRING && dtype == DataType.UINT8) {
+ return true;
+ }
+ return false;
+ }
+
private void throwExceptionIfTypeIsIncompatible(Object o) {
final int rank = numDimensions();
- final int oRank = numDimensions(o);
+ final int oRank = numDimensions(o, dtype);
if (oRank != rank) {
throw new IllegalArgumentException(
String.format(
"cannot copy Tensor with %d dimensions into an object with %d", rank, oRank));
}
- if (dataTypeOf(o) != dtype) {
+ if (!objectCompatWithType(o, dtype)) {
throw new IllegalArgumentException(
String.format(
"cannot copy Tensor with DataType %s into an object of type %s",