aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-13 00:25:14 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-08-02 00:34:31 -0400
commit3359a5fdedb9988ed53879c85e63259d9cefc889 (patch)
tree432a4d5771c81aa7b9251b34e0ee57e059adf2d9 /tensorflow/java/src
parent3379bae787d73d6db67d66a284bd1a076b2cbdba (diff)
Initial draft for Zeros, and add new factories to Constants
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java22
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensors.java1
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java100
3 files changed, 106 insertions, 17 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 24a3775db6..6e82efdf53 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -164,8 +164,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Integer> create(long[] shape, IntBuffer data) {
- Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining());
+ public static Tensor<Integer> create(Shape shape, IntBuffer data) {
+ Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape.asArray(), data.remaining());
t.buffer().asIntBuffer().put(data);
return t;
}
@@ -182,8 +182,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Float> create(long[] shape, FloatBuffer data) {
- Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
+ public static Tensor<Float> create(Shape shape, FloatBuffer data) {
+ Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape.asArray(), data.remaining());
t.buffer().asFloatBuffer().put(data);
return t;
}
@@ -200,8 +200,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
- Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
+ public static Tensor<Double> create(Shape shape, DoubleBuffer data) {
+ Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape.asArray(), data.remaining());
t.buffer().asDoubleBuffer().put(data);
return t;
}
@@ -218,8 +218,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Long> create(long[] shape, LongBuffer data) {
- Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining());
+ public static Tensor<Long> create(Shape shape, LongBuffer data) {
+ Tensor<Long> t = allocateForBuffer(DataType.INT64, shape.asArray(), data.remaining());
t.buffer().asLongBuffer().put(data);
return t;
}
@@ -239,7 +239,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
+ public static <T> Tensor<T> create(Class<T> type, Shape shape, ByteBuffer data) {
@SuppressWarnings("unchecked")
Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
return ret;
@@ -260,7 +260,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
+ private static Tensor<?> create(DataType dtype, Shape shape, ByteBuffer data) {
int nremaining = 0;
if (dtype != DataType.STRING) {
int elemBytes = elemByteSize(dtype);
@@ -274,7 +274,7 @@ public final class Tensor<T> implements AutoCloseable {
} else {
nremaining = data.remaining();
}
- Tensor<?> t = allocateForBuffer(dtype, shape, nremaining);
+ Tensor<?> t = allocateForBuffer(dtype, shape.asArray(), nremaining);
t.buffer().put(data);
return t;
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
index c828d23efc..c6c3117db2 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
@@ -20,7 +20,6 @@ import static java.nio.charset.StandardCharsets.UTF_8;
/** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */
public final class Tensors {
private Tensors() {}
-
/**
* Creates a scalar String tensor using the default, UTF-8 encoding.
*
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index de4049f66b..bcf165346c 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -20,11 +20,15 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+import java.nio.charset.Charset;
+
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.Output;
+import org.tensorflow.Shape;
import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
import org.tensorflow.op.PrimitiveOp;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
@@ -32,6 +36,7 @@ import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
public final class Constant<T> extends PrimitiveOp implements Operand<T> {
+
/**
* Create a constant from a Java object.
*
@@ -54,6 +59,18 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code int} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ */
+ public static Constant<Integer> create(Scope scope, int data) {
+ try (Tensor<Integer> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
* Create a {@link DataType#INT32} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -66,13 +83,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
+ public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
+ * Creates a constant containing a single {@code float} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ */
+ public static Constant<Float> create(Scope scope, float data) {
+ try (Tensor<Float> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -85,13 +114,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
+ public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
+ * Creates a constant containing a single {@code double} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ */
+ public static Constant<Double> create(Scope scope, double data) {
+ try (Tensor<Double> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -104,13 +145,25 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
+ public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
+ * Creates a constant containing a single {@code long} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ */
+ public static Constant<Long> create(Scope scope, long data) {
+ try (Tensor<Long> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
* Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -123,13 +176,50 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
+ public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
/**
+ * Creates a constant containing a single {@code boolean} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean data) {
+ try (Tensor<Boolean> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Creates a String constant using the default, UTF-8 encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The string to put into the new constant.
+ */
+ public static Constant<String> create(Scope scope, String data) {
+ try (Tensor<String> value = Tensors.create(data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Creates a String constant using a specified encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new constant.
+ */
+ public static Constant<String> create(Scope scope, String data, Charset charset) {
+ try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
+ return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class));
+ }
+ }
+
+ /**
* Create a constant with data from the given buffer.
*
* <p>Creates a Constant with the provided shape of any type where the constant data has been
@@ -144,7 +234,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static <T> Constant<T> create(Scope scope, Class<T> type, long[] shape, ByteBuffer data) {
+ public static <T> Constant<T> create(Scope scope, Class<T> type, Shape shape, ByteBuffer data) {
try (Tensor<T> value = Tensor.create(type, shape, data)) {
return createWithTensor(scope, value);
}