diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-07-13 00:25:14 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-08-02 00:34:31 -0400 |
commit | 3359a5fdedb9988ed53879c85e63259d9cefc889 (patch) | |
tree | 432a4d5771c81aa7b9251b34e0ee57e059adf2d9 /tensorflow/java/src | |
parent | 3379bae787d73d6db67d66a284bd1a076b2cbdba (diff) |
Initial draft for Zeros, and add new factories to Constants
Diffstat (limited to 'tensorflow/java/src')
3 files changed, 106 insertions, 17 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 24a3775db6..6e82efdf53 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -164,8 +164,8 @@ public final class Tensor<T> implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor<Integer> create(long[] shape, IntBuffer data) { - Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining()); + public static Tensor<Integer> create(Shape shape, IntBuffer data) { + Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape.asArray(), data.remaining()); t.buffer().asIntBuffer().put(data); return t; } @@ -182,8 +182,8 @@ public final class Tensor<T> implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor<Float> create(long[] shape, FloatBuffer data) { - Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining()); + public static Tensor<Float> create(Shape shape, FloatBuffer data) { + Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape.asArray(), data.remaining()); t.buffer().asFloatBuffer().put(data); return t; } @@ -200,8 +200,8 @@ public final class Tensor<T> implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor<Double> create(long[] shape, DoubleBuffer data) { - Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining()); + public static Tensor<Double> create(Shape shape, DoubleBuffer data) { + Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape.asArray(), data.remaining()); t.buffer().asDoubleBuffer().put(data); return t; } @@ -218,8 +218,8 @@ public final class Tensor<T> implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor<Long> create(long[] shape, LongBuffer data) { - Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining()); + public static Tensor<Long> create(Shape shape, LongBuffer data) { + Tensor<Long> t = allocateForBuffer(DataType.INT64, shape.asArray(), data.remaining()); t.buffer().asLongBuffer().put(data); return t; } @@ -239,7 +239,7 @@ public final class Tensor<T> implements AutoCloseable { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) { + public static <T> Tensor<T> create(Class<T> type, Shape shape, ByteBuffer data) { @SuppressWarnings("unchecked") Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data); return ret; @@ -260,7 +260,7 @@ public final class Tensor<T> implements AutoCloseable { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) { + private static Tensor<?> create(DataType dtype, Shape shape, ByteBuffer data) { int nremaining = 0; if (dtype != DataType.STRING) { int elemBytes = elemByteSize(dtype); @@ -274,7 +274,7 @@ public final class Tensor<T> implements AutoCloseable { } else { nremaining = data.remaining(); } - Tensor<?> t = allocateForBuffer(dtype, shape, nremaining); + Tensor<?> t = allocateForBuffer(dtype, shape.asArray(), nremaining); t.buffer().put(data); return t; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java index c828d23efc..c6c3117db2 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java @@ -20,7 +20,6 @@ import static java.nio.charset.StandardCharsets.UTF_8; /** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */ public final class Tensors { private Tensors() {} - /** * Creates a scalar String tensor using the default, UTF-8 encoding. * diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index de4049f66b..bcf165346c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -20,11 +20,15 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.nio.charset.Charset; + import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; +import org.tensorflow.Shape; import org.tensorflow.Tensor; +import org.tensorflow.Tensors; import org.tensorflow.op.PrimitiveOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Operator; @@ -32,6 +36,7 @@ import org.tensorflow.op.annotation.Operator; /** An operator producing a constant value. */ @Operator public final class Constant<T> extends PrimitiveOp implements Operand<T> { + /** * Create a constant from a Java object. * @@ -54,6 +59,18 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code int} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + */ + public static Constant<Integer> create(Scope scope, int data) { + try (Tensor<Integer> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** * Create a {@link DataType#INT32} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -66,13 +83,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) { + public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) { try (Tensor<Integer> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } /** + * Creates a constant containing a single {@code float} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + */ + public static Constant<Float> create(Scope scope, float data) { + try (Tensor<Float> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** * Create a {@link DataType#FLOAT} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -85,13 +114,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) { + public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) { try (Tensor<Float> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } /** + * Creates a constant containing a single {@code double} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + */ + public static Constant<Double> create(Scope scope, double data) { + try (Tensor<Double> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -104,13 +145,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) { + public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) { try (Tensor<Double> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } /** + * Creates a constant containing a single {@code long} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + */ + public static Constant<Long> create(Scope scope, long data) { + try (Tensor<Long> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** * Create a {@link DataType#INT64} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -123,13 +176,50 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) { + public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) { try (Tensor<Long> value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } /** + * Creates a constant containing a single {@code boolean} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + */ + public static Constant<Boolean> create(Scope scope, boolean data) { + try (Tensor<Boolean> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** + * Creates a String constant using the default, UTF-8 encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param data The string to put into the new constant. + */ + public static Constant<String> create(Scope scope, String data) { + try (Tensor<String> value = Tensors.create(data)) { + return createWithTensor(scope, value); + } + } + + /** + * Creates a String constant using a specified encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param charset The encoding from String to bytes. + * @param data The string to put into the new constant. + */ + public static Constant<String> create(Scope scope, String data, Charset charset) { + try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) { + return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + } + } + + /** * Create a constant with data from the given buffer. * * <p>Creates a Constant with the provided shape of any type where the constant data has been @@ -144,7 +234,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - public static <T> Constant<T> create(Scope scope, Class<T> type, long[] shape, ByteBuffer data) { + public static <T> Constant<T> create(Scope scope, Class<T> type, Shape shape, ByteBuffer data) { try (Tensor<T> value = Tensor.create(type, shape, data)) { return createWithTensor(scope, value); } |