From 4a8cedd26c182b8f866ee3194c4a016d336ec907 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Tue, 24 Jul 2018 08:52:02 -0400 Subject: Add unit tests --- .../java/src/main/java/org/tensorflow/Shape.java | 24 ++++--- .../java/src/main/java/org/tensorflow/Tensor.java | 22 +++---- .../java/src/main/java/org/tensorflow/Tensors.java | 1 + .../main/java/org/tensorflow/op/core/Constant.java | 12 ++-- .../main/java/org/tensorflow/op/core/Zeros.java | 2 +- .../java/org/tensorflow/op/core/ConstantTest.java | 2 +- .../java/org/tensorflow/op/core/ZerosTest.java | 75 +++++++++++++--------- 7 files changed, 79 insertions(+), 59 deletions(-) (limited to 'tensorflow/java/src') diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index 1662a49cb7..a177cdaf7a 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -74,7 +74,7 @@ public final class Shape { * @return The size of the requested dimension or -1 if it is unknown. */ public long size(int i) { - return shape[i]; + return shape == null ? -1 : shape[i]; } /** @@ -88,9 +88,11 @@ public final class Shape { if (shape == null) { return -1; } - long total = 1; + int total = 1; for (int i = 0; i < shape.length; ++i) { - long size = size(i); + // TODO (karllessard): There might be a lossy conversion here from 'long' sizes to 'int' total, but this issue + // seems ubiquitous in the current Java client implementation. It should be adressed all at once. + int size = (int) size(i); if (size < 0) { return -1; } @@ -99,6 +101,16 @@ public final class Shape { return total; } + /** + * Returns the shape as an array. + * + *

Each element represent the size of the dimension at the given index. For example, + * {@code shape.asArray()[4]} is equal to the size of the fourth dimension in this shape. + */ + public long[] asArray() { + return shape; + } + @Override public int hashCode() { return Arrays.hashCode(shape); @@ -131,12 +143,6 @@ public final class Shape { this.shape = shape; } - // Package-private accessor. - // The idea is that the public API does not expose the internal array. - long[] asArray() { - return shape; - } - private long[] shape; private boolean hasUnknownDimension() { diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index a307269ab5..38bb55e59f 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -164,8 +164,8 @@ public final class Tensor implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor create(Shape shape, IntBuffer data) { - Tensor t = allocateForBuffer(DataType.INT32, shape.asArray(), data.remaining()); + public static Tensor create(long[] shape, IntBuffer data) { + Tensor t = allocateForBuffer(DataType.INT32, shape, data.remaining()); t.buffer().asIntBuffer().put(data); return t; } @@ -182,8 +182,8 @@ public final class Tensor implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor create(Shape shape, FloatBuffer data) { - Tensor t = allocateForBuffer(DataType.FLOAT, shape.asArray(), data.remaining()); + public static Tensor create(long[] shape, FloatBuffer data) { + Tensor t = allocateForBuffer(DataType.FLOAT, shape, data.remaining()); t.buffer().asFloatBuffer().put(data); return t; } @@ -200,8 +200,8 @@ public final class Tensor implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor create(Shape shape, DoubleBuffer data) { - Tensor t = allocateForBuffer(DataType.DOUBLE, shape.asArray(), data.remaining()); + public static Tensor create(long[] shape, DoubleBuffer data) { + Tensor t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining()); t.buffer().asDoubleBuffer().put(data); return t; } @@ -218,8 +218,8 @@ public final class Tensor implements AutoCloseable { * @param data a buffer containing the tensor data. * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Tensor create(Shape shape, LongBuffer data) { - Tensor t = allocateForBuffer(DataType.INT64, shape.asArray(), data.remaining()); + public static Tensor create(long[] shape, LongBuffer data) { + Tensor t = allocateForBuffer(DataType.INT64, shape, data.remaining()); t.buffer().asLongBuffer().put(data); return t; } @@ -239,7 +239,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - public static Tensor create(Class type, Shape shape, ByteBuffer data) { + public static Tensor create(Class type, long[] shape, ByteBuffer data) { @SuppressWarnings("unchecked") Tensor ret = (Tensor) create(DataType.fromClass(type), shape, data); return ret; @@ -260,7 +260,7 @@ public final class Tensor implements AutoCloseable { * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ - private static Tensor create(DataType dtype, Shape shape, ByteBuffer data) { + private static Tensor create(DataType dtype, long[] shape, ByteBuffer data) { int nremaining = 0; if (dtype != DataType.STRING) { int elemBytes = elemByteSize(dtype); @@ -274,7 +274,7 @@ public final class Tensor implements AutoCloseable { } else { nremaining = data.remaining(); } - Tensor t = allocateForBuffer(dtype, shape.asArray(), nremaining); + Tensor t = allocateForBuffer(dtype, shape, nremaining); t.buffer().put(data); return t; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java index c6c3117db2..c828d23efc 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java @@ -20,6 +20,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; /** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */ public final class Tensors { private Tensors() {} + /** * Creates a scalar String tensor using the default, UTF-8 encoding. * diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index a3667dfd6e..c71046d983 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -52,7 +52,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, Shape shape, IntBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + try (Tensor value = Tensor.create(shape.asArray(), data)) { return createWithTensor(scope, value); } } @@ -85,7 +85,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, Shape shape, FloatBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + try (Tensor value = Tensor.create(shape.asArray(), data)) { return createWithTensor(scope, value); } } @@ -118,7 +118,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, Shape shape, DoubleBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + try (Tensor value = Tensor.create(shape.asArray(), data)) { return createWithTensor(scope, value); } } @@ -151,7 +151,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant create(Scope scope, Shape shape, LongBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + try (Tensor value = Tensor.create(shape.asArray(), data)) { return createWithTensor(scope, value); } } @@ -226,7 +226,7 @@ public final class Constant extends PrimitiveOp implements Operand { * buffer */ public static Constant create(Scope scope, Class type, Shape shape, ByteBuffer data) { - try (Tensor value = Tensor.create(type, shape, data)) { + try (Tensor value = Tensor.create(type, shape.asArray(), data)) { return createWithTensor(scope, value); } } @@ -239,7 +239,7 @@ public final class Constant extends PrimitiveOp implements Operand { * provided. For example: * *

{@code
-   * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7
+   * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix
    * }
* * @param scope is a scope used to add the underlying operation. diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java index 7dd35bb21f..5bba594e17 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -28,7 +28,7 @@ public class Zeros implements Op, Operand { * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. */ public static Zeros create(Scope scope, Class type, Shape shape) { - int numElements = (int) shape.numElements(); + int numElements = shape.numElements(); if (numElements < 0) { throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants"); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index 177a0789de..63e191cd38 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -169,7 +169,7 @@ public class ConstantTest { @Test public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; - Shape shape = Shape.unknown(); + Shape shape = Shape.scalar(); // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, // followed by a varint encoded size, followed by the data. diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java index d32cc09ae3..ab3446b72b 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -15,7 +15,8 @@ limitations under the License. package org.tensorflow.op.core; -import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import org.junit.Test; import org.junit.runner.RunWith; @@ -33,97 +34,109 @@ public class ZerosTest { @Test public void createIntZeros() { - Shape shape = Shape.make(2, 2); - int[] expected = new int[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, Integer.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); - int[] actual = new int[result.numElements()]; + int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertEquals(0, actual[i][j]); + } + } } } @Test public void createFloatZeros() { - Shape shape = Shape.make(2, 2); - float[] expected = new float[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, Float.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[] actual = new float[shape.numElements()]; + float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual, EPSILON); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } + } } } @Test public void createDoubleZeros() { - Shape shape = Shape.make(2, 2); - double[] expected = new double[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, Double.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[] actual = new double[shape.numElements()]; + double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual, EPSILON); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } + } } } @Test public void createLongZeros() { - Shape shape = Shape.make(2, 2); - float[] expected = new float[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, Long.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - float[] actual = new float[shape.numElements()]; + long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual, 0.0f); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertEquals(0L, actual[i][j]); + } + } } } @Test public void createBooleanZeros() { - Shape shape = Shape.make(2, 2); - boolean[] expected = new boolean[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, Boolean.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); - boolean[] actual = new boolean[shape.numElements()]; + boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertFalse(actual[i][j]); + } + } } } @Test public void createUInt8Zeros() { - Shape shape = Shape.make(2, 2); - byte[] expected = new byte[shape.numElements()]; // all zeros - try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); + Shape shape = Shape.make(2, 2); Zeros op = Zeros.create(scope, UInt8.class, Shape.make(2, 2)); Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class); - byte[] actual = new byte[shape.numElements()]; + byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)]; result.copyTo(actual); - assertArrayEquals(expected, actual); + for (int i = 0; i < shape.size(0); ++i) { + for (int j = 0; j < shape.size(1); ++j) { + assertEquals(0, actual[i][j]); + } + } } } -- cgit v1.2.3