diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Tensor.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Tensor.java | 122 |
1 files changed, 79 insertions, 43 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 4424100390..c5ad1ee51c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -25,6 +25,7 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; import java.util.Arrays; +import java.util.HashMap; /** * A typed multi-dimensional array. @@ -97,9 +98,19 @@ public final class Tensor implements AutoCloseable { * using {@link #create(DataType, long[], ByteBuffer)} instead. */ public static Tensor create(Object obj) { + return create(obj, dataTypeOf(obj)); + } + + /** + * Create a Tensor of data type {@code dtype} from a Java object. + * + * @param dtype the intended tensor data type. It must match the the run-time type of the object. + */ + static Tensor create(Object obj, DataType dtype) { Tensor t = new Tensor(); - t.dtype = dataTypeOf(obj); - t.shapeCopy = new long[numDimensions(obj)]; + t.dtype = dtype; + t.shapeCopy = new long[numDimensions(obj, dtype)]; + assert objectCompatWithType(obj, dtype); fillShape(obj, 0, t.shapeCopy); if (t.dtype != DataType.STRING) { int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy); @@ -190,8 +201,7 @@ public final class Tensor implements AutoCloseable { * * <p>Creates a Tensor with the provided shape of any type where the tensor's data has been * encoded into {@code data} as per the specification of the TensorFlow <a - * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C - * API</a>. + * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>. * * @param dataType the tensor datatype. * @param shape the tensor shape. @@ -537,56 +547,70 @@ public final class Tensor implements AutoCloseable { } } + private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>(); + + static { + classDataTypes.put(int.class, DataType.INT32); + classDataTypes.put(Integer.class, DataType.INT32); + classDataTypes.put(long.class, DataType.INT64); + classDataTypes.put(Long.class, DataType.INT64); + classDataTypes.put(float.class, DataType.FLOAT); + classDataTypes.put(Float.class, DataType.FLOAT); + classDataTypes.put(double.class, DataType.DOUBLE); + classDataTypes.put(Double.class, DataType.DOUBLE); + classDataTypes.put(byte.class, DataType.STRING); + classDataTypes.put(Byte.class, DataType.STRING); + classDataTypes.put(boolean.class, DataType.BOOL); + classDataTypes.put(Boolean.class, DataType.BOOL); + } + private static DataType dataTypeOf(Object o) { - if (o.getClass().isArray()) { - if (Array.getLength(o) == 0) { - throw new IllegalArgumentException("cannot create Tensors with a 0 dimension"); - } - // byte[] is a DataType.STRING scalar. - Object e = Array.get(o, 0); - if (e == null) { - throwExceptionIfNotByteOfByteArrays(o); - return DataType.STRING; - } - if (Byte.class.isInstance(e) || byte.class.isInstance(e)) { - return DataType.STRING; - } - return dataTypeOf(e); + Class<?> c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); } - if (Float.class.isInstance(o) || float.class.isInstance(o)) { - return DataType.FLOAT; - } else if (Double.class.isInstance(o) || double.class.isInstance(o)) { - return DataType.DOUBLE; - } else if (Integer.class.isInstance(o) || int.class.isInstance(o)) { - return DataType.INT32; - } else if (Long.class.isInstance(o) || long.class.isInstance(o)) { - return DataType.INT64; - } else if (Boolean.class.isInstance(o) || boolean.class.isInstance(o)) { - return DataType.BOOL; - } else { - throw new IllegalArgumentException("cannot create Tensors of " + o.getClass().getName()); + DataType ret = classDataTypes.get(c); + if (ret != null) { + return ret; } + throw new IllegalArgumentException("cannot create Tensors of type " + c.getName()); } - private static int numDimensions(Object o) { - if (o.getClass().isArray()) { - Object e = Array.get(o, 0); - if (e == null) { - throwExceptionIfNotByteOfByteArrays(o); - return 1; - } else if (Byte.class.isInstance(e) || byte.class.isInstance(e)) { - return 0; - } - return 1 + numDimensions(e); + /** + * Returns the number of dimensions of a tensor of type dtype when represented by the object o. + */ + private static int numDimensions(Object o, DataType dtype) { + int ret = numArrayDimensions(o); + if (dtype == DataType.STRING && ret > 0) { + return ret - 1; } - return 0; + return ret; } + /** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */ + private static int numArrayDimensions(Object o) { + Class<?> c = o.getClass(); + int i = 0; + while (c.isArray()) { + c = c.getComponentType(); + i++; + } + return i; + } + + /** + * Fills in the remaining entries in the shape array starting from position {@code dim} with the + * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have + * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise. + */ private static void fillShape(Object o, int dim, long[] shape) { if (shape == null || dim == shape.length) { return; } final int len = Array.getLength(o); + if (len == 0) { + throw new IllegalArgumentException("cannot create Tensors with a 0 dimension"); + } if (shape[dim] == 0) { shape[dim] = len; } else if (shape[dim] != len) { @@ -598,15 +622,27 @@ public final class Tensor implements AutoCloseable { } } + /** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */ + private static boolean objectCompatWithType(Object obj, DataType dtype) { + DataType dto = dataTypeOf(obj); + if (dto.equals(dtype)) { + return true; + } + if (dto == DataType.STRING && dtype == DataType.UINT8) { + return true; + } + return false; + } + private void throwExceptionIfTypeIsIncompatible(Object o) { final int rank = numDimensions(); - final int oRank = numDimensions(o); + final int oRank = numDimensions(o, dtype); if (oRank != rank) { throw new IllegalArgumentException( String.format( "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank)); } - if (dataTypeOf(o) != dtype) { + if (!objectCompatWithType(o, dtype)) { throw new IllegalArgumentException( String.format( "cannot copy Tensor with DataType %s into an object of type %s", |