aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-08-02 00:34:03 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-08-02 08:54:48 -0400
commitdde0bf5051591b013b9eee131cd18af9a5c50ebf (patch)
tree938c0352c4ecb9fb078c90e5ada79941a7ca8f5b /tensorflow/java
parent3c9cee97d3feed1354f85dbb5a13564eaaf866b4 (diff)
1st code review: revert shape to long arrays
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java21
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Session.java18
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java42
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java2
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java444
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java46
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java69
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java113
8 files changed, 566 insertions, 189 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index ded09974a4..9dfa9cc68c 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -39,7 +39,7 @@ public enum DataType {
*
* <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes.
*/
- STRING(7),
+ STRING(7, -1),
/** 64-bit signed integer. */
INT64(9, 8),
@@ -49,29 +49,22 @@ public enum DataType {
private final int value;
- private final int sizeInBytes;
+ private final int byteSize;
/**
* @param value must match the corresponding TF_* value in the TensorFlow C API.
+ * @param byteSize size of an element of this type, in bytes, -1 if unknown
*/
- DataType(int value) {
- this(value, -1);
- }
-
- /**
- * @param value must match the corresponding TF_* value in the TensorFlow C API.
- * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown
- */
- DataType(int value, int sizeInBytes) {
+ DataType(int value, int byteSize) {
this.value = value;
- this.sizeInBytes = sizeInBytes;
+ this.byteSize = byteSize;
}
/**
* @return size of an element of this type, in bytes, or -1 if element size is variable
*/
- public int sizeInBytes() {
- return sizeInBytes;
+ public int byteSize() {
+ return byteSize;
}
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index 73324f23e6..a660d25f98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -185,11 +185,20 @@ public final class Session implements AutoCloseable {
return this;
}
- /** Makes {@link #run()} return the Tensor referred to by {@code output}. */
+ /**
+ * Makes {@link #run()} return the Tensor referred to by {@code output}.
+ */
public Runner fetch(Output<?> output) {
outputs.add(output);
return this;
}
+
+ /**
+ * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
+ */
+ public Runner fetch(Operand<?> operand) {
+ return fetch(operand.asOutput());
+ }
/**
* Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
@@ -209,6 +218,13 @@ public final class Session implements AutoCloseable {
targets.add(operation);
return this;
}
+
+ /**
+ * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
+ */
+ public Runner addTarget(Operand<?> operand) {
+ return addTarget(operand.asOutput().op());
+ }
/**
* (Experimental method): set options (typically for debugging) for this run.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index d99b0078f6..d533c3d480 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -74,41 +74,7 @@ public final class Shape {
* @return The size of the requested dimension or -1 if it is unknown.
*/
public long size(int i) {
- return shape == null ? -1 : shape[i];
- }
-
- /**
- * The total number of elements found in a tensor of this shape.
- *
- * <p>If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned.
- *
- * @return the number of elements or -1 if size of some dimension are unknown
- */
- public int numElements() {
- if (shape == null) {
- return -1;
- }
- int total = 1;
- for (int i = 0; i < shape.length; ++i) {
- // TODO (karllessard): There might be a lossy conversion here from 'long' sizes to 'int' total, but this issue
- // seems ubiquitous in the current Java client implementation. It should be adressed all at once.
- int size = (int) size(i);
- if (size < 0) {
- return -1;
- }
- total *= size;
- }
- return total;
- }
-
- /**
- * Returns the shape as an array.
- *
- * <p>Each element represent the size of the dimension at the given index. For example,
- * {@code shape.asArray()[2]} is equal to the size of the third dimension of this shape.
- */
- public long[] asArray() {
- return shape;
+ return shape[i];
}
@Override
@@ -143,6 +109,12 @@ public final class Shape {
this.shape = shape;
}
+ // Package-private accessor.
+ // The idea is that the public API does not expose the internal array.
+ long[] asArray() {
+ return shape;
+ }
+
private long[] shape;
private boolean hasUnknownDimension() {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 38bb55e59f..8987253768 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -595,7 +595,7 @@ public final class Tensor<T> implements AutoCloseable {
}
private static int elemByteSize(DataType dataType) {
- int size = dataType.sizeInBytes();
+ int size = dataType.byteSize();
if (size < 0) {
throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index c71046d983..d7a06380c7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -15,6 +15,8 @@ limitations under the License.
package org.tensorflow.op.core;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
@@ -26,9 +28,7 @@ import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.Output;
-import org.tensorflow.Shape;
import org.tensorflow.Tensor;
-import org.tensorflow.Tensors;
import org.tensorflow.op.PrimitiveOp;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
@@ -51,8 +51,8 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return an integer constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
- try (Tensor<Integer> value = Tensor.create(shape.asArray(), data)) {
+ public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
+ try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -65,9 +65,73 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return an integer constant
*/
public static Constant<Integer> create(Scope scope, int data) {
- try (Tensor<Integer> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][][] data) {
+ return create(scope, data, Integer.class);
}
/**
@@ -84,8 +148,8 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a float constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
- try (Tensor<Float> value = Tensor.create(shape.asArray(), data)) {
+ public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
+ try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -98,9 +162,73 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a float constant
*/
public static Constant<Float> create(Scope scope, float data) {
- try (Tensor<Float> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][][] data) {
+ return create(scope, data, Float.class);
}
/**
@@ -117,8 +245,8 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a double constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
- try (Tensor<Double> value = Tensor.create(shape.asArray(), data)) {
+ public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
+ try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -131,9 +259,73 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a double constant
*/
public static Constant<Double> create(Scope scope, double data) {
- try (Tensor<Double> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][][] data) {
+ return create(scope, data, Double.class);
}
/**
@@ -150,8 +342,8 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a long constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
- try (Tensor<Long> value = Tensor.create(shape.asArray(), data)) {
+ public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
+ try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -164,9 +356,73 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a long constant
*/
public static Constant<Long> create(Scope scope, long data) {
- try (Tensor<Long> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][][] data) {
+ return create(scope, data, Long.class);
}
/**
@@ -177,9 +433,73 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a boolean constant
*/
public static Constant<Boolean> create(Scope scope, boolean data) {
- try (Tensor<Boolean> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][][] data) {
+ return create(scope, data, Boolean.class);
}
/**
@@ -190,9 +510,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @return a string constant
*/
public static Constant<String> create(Scope scope, String data) {
- try (Tensor<String> value = Tensors.create(data)) {
- return createWithTensor(scope, value);
- }
+ return create(scope, data, UTF_8);
}
/**
@@ -225,13 +543,79 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static <T> Constant<T> create(Scope scope, Class<T> type, Shape shape, ByteBuffer data) {
- try (Tensor<T> value = Tensor.create(type, shape.asArray(), data)) {
+ public static <T> Constant<T> create(Scope scope, Class<T> type, long[] shape, ByteBuffer data) {
+ try (Tensor<T> value = Tensor.create(type, shape, data)) {
return createWithTensor(scope, value);
}
}
/**
+ * Creates a rank-1 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code byte} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
* Create a constant from a Java object.
*
* <p>The argument {@code object} is first converted into a Tensor using {@link
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
index 5bba594e17..cc46ce3c5b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -5,13 +5,17 @@ import java.nio.ByteBuffer;
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Output;
-import org.tensorflow.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Operator;
/**
- * An operator creating a constant initialized with zeros w.r.t its type and shape.
+ * An operator creating a constant initialized with zeros of the shape given by `dims`.
+ *
+ * <p>For example, the following expression
+ * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre>
+ * is the equivalent of
+ * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre>
*
* @param <T> constant type
*/
@@ -19,38 +23,32 @@ import org.tensorflow.op.annotation.Operator;
public class Zeros<T> implements Op, Operand<T> {
/**
- * Factory method for this operator
+ * Creates a zeroed tensor given its type and shape.
*
- * @param scope is a scope used to add the underlying operation.
- * @param type the tensor datatype.
- * @param shape the tensor shape.
- * @return a constant initialized with zeros
+ * @param scope is a scope used to add the underlying operation
+ * @param dims a 1-D operand that represents the shape of the output tensor
+ * @param type the output tensor datatype
+ * @return a constant tensor initialized with zeros
* @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
*/
- public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
- int numElements = shape.numElements();
- if (numElements < 0) {
- throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
+ public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
+ Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros"
+ int zeroSize = DataType.fromClass(type).byteSize();
+ if (zeroSize < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros");
}
- int sizeInBytes = DataType.fromClass(type).sizeInBytes();
- if (sizeInBytes < 0) {
- throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros");
- }
- return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes)));
+ Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize));
+ return new Zeros<T>(Fill.create(childScope, dims, zero));
}
@Override
public Output<T> asOutput() {
- return constant.asOutput();
- }
-
- public Constant<T> constant() {
- return constant;
+ return fill.asOutput();
}
- private final Constant<T> constant;
+ private final Fill<T> fill;
- private Zeros(Constant<T> constant) {
- this.constant = constant;
+ private Zeros(Fill<T> fill) {
+ this.fill = fill;
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index 63e191cd38..7d3b26de8d 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -33,7 +33,6 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
-import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
@@ -49,23 +48,25 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, value);
- Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
- assertEquals(value, result.intValue());
+ try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) {
+ assertEquals(value, result.intValue());
+ }
}
}
@Test
public void createIntBuffer() {
int[] ints = {1, 2, 3, 4};
- Shape shape = Shape.make(4);
+ long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
- int[] actual = new int[ints.length];
- assertArrayEquals(ints, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual));
+ }
}
}
@@ -77,23 +78,25 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Float> op = Constant.create(scope, value);
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- assertEquals(value, result.floatValue(), 0.0f);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Float.class).floatValue(), 0.0f);
+ }
}
}
@Test
public void createFloatBuffer() {
float[] floats = {1, 2, 3, 4};
- Shape shape = Shape.make(4);
+ long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[] actual = new float[floats.length];
- assertArrayEquals(floats, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON);
+ }
}
}
@@ -105,23 +108,25 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Double> op = Constant.create(scope, value);
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- assertEquals(value, result.doubleValue(), 0.0);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Double.class).doubleValue(), 0.0);
+ }
}
}
@Test
public void createDoubleBuffer() {
double[] doubles = {1, 2, 3, 4};
- Shape shape = Shape.make(4);
+ long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[] actual = new double[doubles.length];
- assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON);
+ }
}
}
@@ -133,23 +138,25 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Long> op = Constant.create(scope, value);
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- assertEquals(value, result.longValue());
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Long.class).longValue());
+ }
}
}
@Test
public void createLongBuffer() {
long[] longs = {1, 2, 3, 4};
- Shape shape = Shape.make(4);
+ long[] shape = {4};
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- long[] actual = new long[longs.length];
- assertArrayEquals(longs, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, result.expect(Long.class).copyTo(actual));
+ }
}
}
@@ -161,15 +168,16 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Boolean> op = Constant.create(scope, value);
- Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
- assertEquals(value, result.booleanValue());
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Boolean.class).booleanValue());
+ }
}
}
@Test
public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
- Shape shape = Shape.scalar();
+ long[] shape = {};
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
// followed by a varint encoded size, followed by the data.
@@ -190,8 +198,9 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
- Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class);
- assertArrayEquals(data, result.bytesValue());
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertArrayEquals(data, result.expect(String.class).bytesValue());
+ }
}
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
index ab3446b72b..24339d92e6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -18,12 +18,13 @@ package org.tensorflow.op.core;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import java.util.List;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
-import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
import org.tensorflow.types.UInt8;
@@ -37,14 +38,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Integer> op = Zeros.create(scope, Integer.class, Shape.make(2, 2));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
- int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
}
}
}
@@ -55,14 +56,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Float> op = Zeros.create(scope, Float.class, Shape.make(2, 2));
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0.0f, actual[i][j], EPSILON);
+ long[] shape = {2, 2};
+ Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0f, actual[i][j], EPSILON);
+ }
}
}
}
@@ -73,14 +74,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Double> op = Zeros.create(scope, Double.class, Shape.make(2, 2));
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0.0, actual[i][j], EPSILON);
+ long[] shape = {2, 2};
+ Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0, actual[i][j], EPSILON);
+ }
}
}
}
@@ -91,14 +92,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Long> op = Zeros.create(scope, Long.class, Shape.make(2, 2));
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0L, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0L, actual[i][j]);
+ }
}
}
}
@@ -109,14 +110,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Boolean> op = Zeros.create(scope, Boolean.class, Shape.make(2, 2));
- Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
- boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertFalse(actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertFalse(actual[i][j]);
+ }
}
}
}
@@ -127,14 +128,15 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<UInt8> op = Zeros.create(scope, UInt8.class, Shape.make(2, 2));
- Tensor<UInt8> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class);
- byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
+ result.copyTo(actual);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
}
}
}
@@ -145,16 +147,19 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Zeros.create(scope, String.class, Shape.make(2, 2));
+ long[] shape = {2, 2};
+ Zeros.create(scope, Constant.create(scope, shape), String.class);
}
}
-
- @Test(expected = IllegalArgumentException.class)
- public void cannotCreateZerosWithUnknownDimensions() {
+
+ @Test
+ public void operationsComposingZerosAreCorrectlyNamed() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Zeros.create(scope, Float.class, Shape.make(2, -1));
+ long[] shape = {2, 2};
+ Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
+ List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
}
}
}