diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java | 93 |
1 files changed, 53 insertions, 40 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index bcf165346c..a3667dfd6e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -38,22 +38,21 @@ import org.tensorflow.op.annotation.Operator; public final class Constant<T> extends PrimitiveOp implements Operand<T> { /** - * Create a constant from a Java object. - * - * <p>The argument {@code object} is first converted into a Tensor using {@link - * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be - * provided. For example: + * Create a {@link DataType#INT32} constant with data from the given buffer. * - * <pre>{@code - * Constant.create(scope, 7); // returns a constant scalar tensor 7 - * }</pre> + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. * * @param scope is a scope used to add the underlying operation. - * @param object a Java object representing the constant. - * @see org.tensorflow.Tensor#create(Object) Tensor.create + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @return an integer constant + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { - try (Tensor<T> value = Tensor.create(object, type)) { + public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) { + try (Tensor<Integer> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -63,6 +62,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return an integer constant */ public static Constant<Integer> create(Scope scope, int data) { try (Tensor<Integer> value = Tensors.create(data)) { @@ -71,7 +71,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** - * Create a {@link DataType#INT32} constant with data from the given buffer. + * Create a {@link DataType#FLOAT} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -81,10 +81,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) { - try (Tensor<Integer> value = Tensor.create(shape, data)) { + public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) { + try (Tensor<Float> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -94,6 +95,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a float constant */ public static Constant<Float> create(Scope scope, float data) { try (Tensor<Float> value = Tensors.create(data)) { @@ -102,7 +104,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** - * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -112,10 +114,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) { - try (Tensor<Float> value = Tensor.create(shape, data)) { + public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) { + try (Tensor<Double> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -125,6 +128,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a double constant */ public static Constant<Double> create(Scope scope, double data) { try (Tensor<Double> value = Tensors.create(data)) { @@ -133,7 +137,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** - * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * Create a {@link DataType#INT64} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -143,10 +147,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) { - try (Tensor<Double> value = Tensor.create(shape, data)) { + public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) { + try (Tensor<Long> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -156,6 +161,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a long constant */ public static Constant<Long> create(Scope scope, long data) { try (Tensor<Long> value = Tensors.create(data)) { @@ -164,29 +170,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** - * Create a {@link DataType#INT64} constant with data from the given buffer. - * - * <p>Creates a constant with the given shape by copying elements from the buffer (starting from - * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents - * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this - * method. - * - * @param scope is a scope used to add the underlying operation. - * @param shape the tensor shape. - * @param data a buffer containing the tensor data. - * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer - */ - public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) { - try (Tensor<Long> value = Tensor.create(shape, data)) { - return createWithTensor(scope, value); - } - } - - /** * Creates a constant containing a single {@code boolean} element. * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a boolean constant */ public static Constant<Boolean> create(Scope scope, boolean data) { try (Tensor<Boolean> value = Tensors.create(data)) { @@ -199,6 +187,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * * @param scope is a scope used to add the underlying operation. * @param data The string to put into the new constant. + * @return a string constant */ public static Constant<String> create(Scope scope, String data) { try (Tensor<String> value = Tensors.create(data)) { @@ -212,6 +201,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param charset The encoding from String to bytes. * @param data The string to put into the new constant. + * @return a string constant */ public static Constant<String> create(Scope scope, String data, Charset charset) { try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) { @@ -231,6 +221,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param type the tensor datatype. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ @@ -240,6 +231,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } } + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7 + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @return a constant of type `type` + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { + try (Tensor<T> value = Tensor.create(object, type)) { + return createWithTensor(scope, value); + } + } + private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) { return new Constant<T>( scope |