aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java93
1 files changed, 53 insertions, 40 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index bcf165346c..a3667dfd6e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -38,22 +38,21 @@ import org.tensorflow.op.annotation.Operator;
public final class Constant<T> extends PrimitiveOp implements Operand<T> {
/**
- * Create a constant from a Java object.
- *
- * <p>The argument {@code object} is first converted into a Tensor using {@link
- * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
- * provided. For example:
+ * Create a {@link DataType#INT32} constant with data from the given buffer.
*
- * <pre>{@code
- * Constant.create(scope, 7); // returns a constant scalar tensor 7
- * }</pre>
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
*
* @param scope is a scope used to add the underlying operation.
- * @param object a Java object representing the constant.
- * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @return an integer constant
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
- try (Tensor<T> value = Tensor.create(object, type)) {
+ public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
+ try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -63,6 +62,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return an integer constant
*/
public static Constant<Integer> create(Scope scope, int data) {
try (Tensor<Integer> value = Tensors.create(data)) {
@@ -71,7 +71,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#INT32} constant with data from the given buffer.
+ * Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -81,10 +81,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a float constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
- try (Tensor<Integer> value = Tensor.create(shape, data)) {
+ public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
+ try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -94,6 +95,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a float constant
*/
public static Constant<Float> create(Scope scope, float data) {
try (Tensor<Float> value = Tensors.create(data)) {
@@ -102,7 +104,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#FLOAT} constant with data from the given buffer.
+ * Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -112,10 +114,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a double constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
- try (Tensor<Float> value = Tensor.create(shape, data)) {
+ public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
+ try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -125,6 +128,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a double constant
*/
public static Constant<Double> create(Scope scope, double data) {
try (Tensor<Double> value = Tensors.create(data)) {
@@ -133,7 +137,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#DOUBLE} constant with data from the given buffer.
+ * Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -143,10 +147,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a long constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
- try (Tensor<Double> value = Tensor.create(shape, data)) {
+ public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
+ try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -156,6 +161,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a long constant
*/
public static Constant<Long> create(Scope scope, long data) {
try (Tensor<Long> value = Tensors.create(data)) {
@@ -164,29 +170,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#INT64} constant with data from the given buffer.
- *
- * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
- * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
- * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
- * method.
- *
- * @param scope is a scope used to add the underlying operation.
- * @param shape the tensor shape.
- * @param data a buffer containing the tensor data.
- * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
- */
- public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
- try (Tensor<Long> value = Tensor.create(shape, data)) {
- return createWithTensor(scope, value);
- }
- }
-
- /**
* Creates a constant containing a single {@code boolean} element.
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a boolean constant
*/
public static Constant<Boolean> create(Scope scope, boolean data) {
try (Tensor<Boolean> value = Tensors.create(data)) {
@@ -199,6 +187,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The string to put into the new constant.
+ * @return a string constant
*/
public static Constant<String> create(Scope scope, String data) {
try (Tensor<String> value = Tensors.create(data)) {
@@ -212,6 +201,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param charset The encoding from String to bytes.
* @param data The string to put into the new constant.
+ * @return a string constant
*/
public static Constant<String> create(Scope scope, String data, Charset charset) {
try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
@@ -231,6 +221,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param type the tensor datatype.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a constant of type `type`
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
@@ -240,6 +231,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
}
+ /**
+ * Create a constant from a Java object.
+ *
+ * <p>The argument {@code object} is first converted into a Tensor using {@link
+ * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
+ * provided. For example:
+ *
+ * <pre>{@code
+ * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7
+ * }</pre>
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param object a Java object representing the constant.
+ * @return a constant of type `type`
+ * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ */
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
return new Constant<T>(
scope