From dde0bf5051591b013b9eee131cd18af9a5c50ebf Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Thu, 2 Aug 2018 00:34:03 -0400 Subject: 1st code review: revert shape to long arrays --- .../src/main/java/org/tensorflow/DataType.java | 21 +- .../java/src/main/java/org/tensorflow/Session.java | 18 +- .../java/src/main/java/org/tensorflow/Shape.java | 42 +- .../java/src/main/java/org/tensorflow/Tensor.java | 2 +- .../main/java/org/tensorflow/op/core/Constant.java | 444 +++++++++++++++++++-- .../main/java/org/tensorflow/op/core/Zeros.java | 46 +-- .../java/org/tensorflow/op/core/ConstantTest.java | 69 ++-- .../java/org/tensorflow/op/core/ZerosTest.java | 113 +++--- 8 files changed, 566 insertions(+), 189 deletions(-) (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index ded09974a4..9dfa9cc68c 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -39,7 +39,7 @@ public enum DataType { * *

TensorFlow uses the STRING type for an arbitrary sequence of bytes. */ - STRING(7), + STRING(7, -1), /** 64-bit signed integer. */ INT64(9, 8), @@ -49,29 +49,22 @@ public enum DataType { private final int value; - private final int sizeInBytes; + private final int byteSize; /** * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param byteSize size of an element of this type, in bytes, -1 if unknown */ - DataType(int value) { - this(value, -1); - } - - /** - * @param value must match the corresponding TF_* value in the TensorFlow C API. - * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown - */ - DataType(int value, int sizeInBytes) { + DataType(int value, int byteSize) { this.value = value; - this.sizeInBytes = sizeInBytes; + this.byteSize = byteSize; } /** * @return size of an element of this type, in bytes, or -1 if element size is variable */ - public int sizeInBytes() { - return sizeInBytes; + public int byteSize() { + return byteSize; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index 73324f23e6..a660d25f98 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -185,11 +185,20 @@ public final class Session implements AutoCloseable { return this; } - /** Makes {@link #run()} return the Tensor referred to by {@code output}. */ + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + */ public Runner fetch(Output output) { outputs.add(output); return this; } + + /** + * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + */ + public Runner fetch(Operand operand) { + return fetch(operand.asOutput()); + } /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. @@ -209,6 +218,13 @@ public final class Session implements AutoCloseable { targets.add(operation); return this; } + + /** + * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s. + */ + public Runner addTarget(Operand operand) { + return addTarget(operand.asOutput().op()); + } /** * (Experimental method): set options (typically for debugging) for this run. diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index d99b0078f6..d533c3d480 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -74,41 +74,7 @@ public final class Shape { * @return The size of the requested dimension or -1 if it is unknown. */ public long size(int i) { - return shape == null ? -1 : shape[i]; - } - - /** - * The total number of elements found in a tensor of this shape. - * - *

If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned. - * - * @return the number of elements or -1 if size of some dimension are unknown - */ - public int numElements() { - if (shape == null) { - return -1; - } - int total = 1; - for (int i = 0; i < shape.length; ++i) { - // TODO (karllessard): There might be a lossy conversion here from 'long' sizes to 'int' total, but this issue - // seems ubiquitous in the current Java client implementation. It should be adressed all at once. - int size = (int) size(i); - if (size < 0) { - return -1; - } - total *= size; - } - return total; - } - - /** - * Returns the shape as an array. - * - *

Each element represent the size of the dimension at the given index. For example, - * {@code shape.asArray()[2]} is equal to the size of the third dimension of this shape. - */ - public long[] asArray() { - return shape; + return shape[i]; } @Override @@ -143,6 +109,12 @@ public final class Shape { this.shape = shape; } + // Package-private accessor. + // The idea is that the public API does not expose the internal array. + long[] asArray() { + return shape; + } + private long[] shape; private boolean hasUnknownDimension() { diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 38bb55e59f..8987253768 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -595,7 +595,7 @@ public final class Tensor implements AutoCloseable { } private static int elemByteSize(DataType dataType) { - int size = dataType.sizeInBytes(); + int size = dataType.byteSize(); if (size < 0) { throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index c71046d983..d7a06380c7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -15,6 +15,8 @@ limitations under the License. package org.tensorflow.op.core; +import static java.nio.charset.StandardCharsets.UTF_8; + import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; @@ -26,9 +28,7 @@ import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; import org.tensorflow.Output; -import org.tensorflow.Shape; import org.tensorflow.Tensor; -import org.tensorflow.Tensors; import org.tensorflow.op.PrimitiveOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Operator; @@ -51,8 +51,8 @@ public final class Constant extends PrimitiveOp implements Operand { * @return an integer constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, IntBuffer data) { - try (Tensor value = Tensor.create(shape.asArray(), data)) { + public static Constant create(Scope scope, long[] shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -65,9 +65,73 @@ public final class Constant extends PrimitiveOp implements Operand { * @return an integer constant */ public static Constant create(Scope scope, int data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-1 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-2 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-3 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-4 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-5 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-6 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, int[][][][][][] data) { + return create(scope, data, Integer.class); } /** @@ -84,8 +148,8 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, FloatBuffer data) { - try (Tensor value = Tensor.create(shape.asArray(), data)) { + public static Constant create(Scope scope, long[] shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -98,9 +162,73 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a float constant */ public static Constant create(Scope scope, float data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, Float.class); + } + + /** + * Creates a rank-1 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-2 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-3 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-4 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-5 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-6 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, float[][][][][][] data) { + return create(scope, data, Float.class); } /** @@ -117,8 +245,8 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, DoubleBuffer data) { - try (Tensor value = Tensor.create(shape.asArray(), data)) { + public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -131,9 +259,73 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a double constant */ public static Constant create(Scope scope, double data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, Double.class); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-2 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-3 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-4 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-5 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-6 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, double[][][][][][] data) { + return create(scope, data, Double.class); } /** @@ -150,8 +342,8 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, LongBuffer data) { - try (Tensor value = Tensor.create(shape.asArray(), data)) { + public static Constant create(Scope scope, long[] shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -164,9 +356,73 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a long constant */ public static Constant create(Scope scope, long data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, Long.class); + } + + /** + * Creates a rank-1 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-2 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-3 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-4 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-5 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-6 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, long[][][][][][] data) { + return create(scope, data, Long.class); } /** @@ -177,9 +433,73 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a boolean constant */ public static Constant create(Scope scope, boolean data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-1 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-2 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-3 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-4 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-5 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-6 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant create(Scope scope, boolean[][][][][][] data) { + return create(scope, data, Boolean.class); } /** @@ -190,9 +510,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @return a string constant */ public static Constant create(Scope scope, String data) { - try (Tensor value = Tensors.create(data)) { - return createWithTensor(scope, value); - } + return create(scope, data, UTF_8); } /** @@ -225,12 +543,78 @@ public final class Constant extends PrimitiveOp implements Operand { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - public static Constant create(Scope scope, Class type, Shape shape, ByteBuffer data) { - try (Tensor value = Tensor.create(type, shape.asArray(), data)) { + public static Constant create(Scope scope, Class type, long[] shape, ByteBuffer data) { + try (Tensor value = Tensor.create(type, shape, data)) { return createWithTensor(scope, value); } } + /** + * Creates a rank-1 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-2 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-3 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-4 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-5 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-6 constant of {@code byte} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant create(Scope scope, byte[][][][][][] data) { + return create(scope, data, String.class); + } + /** * Create a constant from a Java object. * diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java index 5bba594e17..cc46ce3c5b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -5,13 +5,17 @@ import java.nio.ByteBuffer; import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Operator; /** - * An operator creating a constant initialized with zeros w.r.t its type and shape. + * An operator creating a constant initialized with zeros of the shape given by `dims`. + * + *

For example, the following expression + *

{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)
+ * is the equivalent of + *
{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))
* * @param constant type */ @@ -19,38 +23,32 @@ import org.tensorflow.op.annotation.Operator; public class Zeros implements Op, Operand { /** - * Factory method for this operator + * Creates a zeroed tensor given its type and shape. * - * @param scope is a scope used to add the underlying operation. - * @param type the tensor datatype. - * @param shape the tensor shape. - * @return a constant initialized with zeros + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype + * @return a constant tensor initialized with zeros * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. */ - public static Zeros create(Scope scope, Class type, Shape shape) { - int numElements = shape.numElements(); - if (numElements < 0) { - throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants"); + public static Zeros create(Scope scope, Operand dims, Class type) { + Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros" + int zeroSize = DataType.fromClass(type).byteSize(); + if (zeroSize < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros"); } - int sizeInBytes = DataType.fromClass(type).sizeInBytes(); - if (sizeInBytes < 0) { - throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros"); - } - return new Zeros(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes))); + Constant zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize)); + return new Zeros(Fill.create(childScope, dims, zero)); } @Override public Output asOutput() { - return constant.asOutput(); - } - - public Constant constant() { - return constant; + return fill.asOutput(); } - private final Constant constant; + private final Fill fill; - private Zeros(Constant constant) { - this.constant = constant; + private Zeros(Fill fill) { + this.fill = fill; } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index 63e191cd38..7d3b26de8d 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -33,7 +33,6 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.tensorflow.op.Scope; @@ -49,23 +48,25 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, value); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); - assertEquals(value, result.intValue()); + try (Tensor result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { + assertEquals(value, result.intValue()); + } } } @Test public void createIntBuffer() { int[] ints = {1, 2, 3, 4}; - Shape shape = Shape.make(4); + long[] shape = {4}; try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); - int[] actual = new int[ints.length]; - assertArrayEquals(ints, result.copyTo(actual)); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); + } } } @@ -77,23 +78,25 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, value); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - assertEquals(value, result.floatValue(), 0.0f); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); + } } } @Test public void createFloatBuffer() { float[] floats = {1, 2, 3, 4}; - Shape shape = Shape.make(4); + long[] shape = {4}; try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[] actual = new float[floats.length]; - assertArrayEquals(floats, result.copyTo(actual), EPSILON); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); + } } } @@ -105,23 +108,25 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, value); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - assertEquals(value, result.doubleValue(), 0.0); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); + } } } @Test public void createDoubleBuffer() { double[] doubles = {1, 2, 3, 4}; - Shape shape = Shape.make(4); + long[] shape = {4}; try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[] actual = new double[doubles.length]; - assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); + } } } @@ -133,23 +138,25 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, value); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - assertEquals(value, result.longValue()); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Long.class).longValue()); + } } } @Test public void createLongBuffer() { long[] longs = {1, 2, 3, 4}; - Shape shape = Shape.make(4); + long[] shape = {4}; try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[] actual = new long[longs.length]; - assertArrayEquals(longs, result.copyTo(actual)); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); + } } } @@ -161,15 +168,16 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, value); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); - assertEquals(value, result.booleanValue()); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Boolean.class).booleanValue()); + } } } @Test public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; - Shape shape = Shape.scalar(); + long[] shape = {}; // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, // followed by a varint encoded size, followed by the data. @@ -190,8 +198,9 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class); - assertArrayEquals(data, result.bytesValue()); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + assertArrayEquals(data, result.expect(String.class).bytesValue()); + } } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java index ab3446b72b..24339d92e6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -18,12 +18,13 @@ package org.tensorflow.op.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import java.util.List; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.tensorflow.op.Scope; import org.tensorflow.types.UInt8; @@ -37,14 +38,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, Integer.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); - int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0, actual[i][j]); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); + try (Tensor result = sess.runner().fetch(op).run().get(0)) { + int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } } } } @@ -55,14 +56,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, Float.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0.0f, actual[i][j], EPSILON); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Float.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } } } } @@ -73,14 +74,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, Double.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0.0, actual[i][j], EPSILON); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Double.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } } } } @@ -91,14 +92,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, Long.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0L, actual[i][j]); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Long.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0L, actual[i][j]); + } } } } @@ -109,14 +110,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, Boolean.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); - boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertFalse(actual[i][j]); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertFalse(actual[i][j]); + } } } } @@ -127,14 +128,15 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros op = Zeros.create(scope, UInt8.class, Shape.make(2, 2)); - Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class); - byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0, actual[i][j]); + long[] shape = {2, 2}; + Zeros op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0)) { + byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); + result.copyTo(actual); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } } } } @@ -145,16 +147,19 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Zeros.create(scope, String.class, Shape.make(2, 2)); + long[] shape = {2, 2}; + Zeros.create(scope, Constant.create(scope, shape), String.class); } } - - @Test(expected = IllegalArgumentException.class) - public void cannotCreateZerosWithUnknownDimensions() { + + @Test + public void operationsComposingZerosAreCorrectlyNamed() { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Zeros.create(scope, Float.class, Shape.make(2, -1)); + long[] shape = {2, 2}; + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); + List> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } } -- cgit v1.2.3