aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-24 08:52:02 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-08-02 00:35:31 -0400
commit4a8cedd26c182b8f866ee3194c4a016d336ec907 (patch)
treef22693da97970d5a5b79da24a13be1204f635081 /tensorflow/java/src
parentbabd79185946130e86ccc3176b7071db5f274309 (diff)
Add unit tests
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java24
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java22
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensors.java1
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java12
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java75
7 files changed, 79 insertions, 59 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index 1662a49cb7..a177cdaf7a 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -74,7 +74,7 @@ public final class Shape {
* @return The size of the requested dimension or -1 if it is unknown.
*/
public long size(int i) {
- return shape[i];
+ return shape == null ? -1 : shape[i];
}
/**
@@ -88,9 +88,11 @@ public final class Shape {
if (shape == null) {
return -1;
}
- long total = 1;
+ int total = 1;
for (int i = 0; i < shape.length; ++i) {
- long size = size(i);
+ // TODO (karllessard): There might be a lossy conversion here from 'long' sizes to 'int' total, but this issue
+ // seems ubiquitous in the current Java client implementation. It should be adressed all at once.
+ int size = (int) size(i);
if (size < 0) {
return -1;
}
@@ -99,6 +101,16 @@ public final class Shape {
return total;
}
+ /**
+ * Returns the shape as an array.
+ *
+ * <p>Each element represent the size of the dimension at the given index. For example,
+ * {@code shape.asArray()[4]} is equal to the size of the fourth dimension in this shape.
+ */
+ public long[] asArray() {
+ return shape;
+ }
+
@Override
public int hashCode() {
return Arrays.hashCode(shape);
@@ -131,12 +143,6 @@ public final class Shape {
this.shape = shape;
}
- // Package-private accessor.
- // The idea is that the public API does not expose the internal array.
- long[] asArray() {
- return shape;
- }
-
private long[] shape;
private boolean hasUnknownDimension() {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index a307269ab5..38bb55e59f 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -164,8 +164,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Integer> create(Shape shape, IntBuffer data) {
- Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape.asArray(), data.remaining());
+ public static Tensor<Integer> create(long[] shape, IntBuffer data) {
+ Tensor<Integer> t = allocateForBuffer(DataType.INT32, shape, data.remaining());
t.buffer().asIntBuffer().put(data);
return t;
}
@@ -182,8 +182,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Float> create(Shape shape, FloatBuffer data) {
- Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape.asArray(), data.remaining());
+ public static Tensor<Float> create(long[] shape, FloatBuffer data) {
+ Tensor<Float> t = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
t.buffer().asFloatBuffer().put(data);
return t;
}
@@ -200,8 +200,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Double> create(Shape shape, DoubleBuffer data) {
- Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape.asArray(), data.remaining());
+ public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
+ Tensor<Double> t = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
t.buffer().asDoubleBuffer().put(data);
return t;
}
@@ -218,8 +218,8 @@ public final class Tensor<T> implements AutoCloseable {
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Tensor<Long> create(Shape shape, LongBuffer data) {
- Tensor<Long> t = allocateForBuffer(DataType.INT64, shape.asArray(), data.remaining());
+ public static Tensor<Long> create(long[] shape, LongBuffer data) {
+ Tensor<Long> t = allocateForBuffer(DataType.INT64, shape, data.remaining());
t.buffer().asLongBuffer().put(data);
return t;
}
@@ -239,7 +239,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- public static <T> Tensor<T> create(Class<T> type, Shape shape, ByteBuffer data) {
+ public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
@SuppressWarnings("unchecked")
Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
return ret;
@@ -260,7 +260,7 @@ public final class Tensor<T> implements AutoCloseable {
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
- private static Tensor<?> create(DataType dtype, Shape shape, ByteBuffer data) {
+ private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
int nremaining = 0;
if (dtype != DataType.STRING) {
int elemBytes = elemByteSize(dtype);
@@ -274,7 +274,7 @@ public final class Tensor<T> implements AutoCloseable {
} else {
nremaining = data.remaining();
}
- Tensor<?> t = allocateForBuffer(dtype, shape.asArray(), nremaining);
+ Tensor<?> t = allocateForBuffer(dtype, shape, nremaining);
t.buffer().put(data);
return t;
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
index c6c3117db2..c828d23efc 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensors.java
@@ -20,6 +20,7 @@ import static java.nio.charset.StandardCharsets.UTF_8;
/** Type-safe factory methods for creating {@link org.tensorflow.Tensor} objects. */
public final class Tensors {
private Tensors() {}
+
/**
* Creates a scalar String tensor using the default, UTF-8 encoding.
*
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index a3667dfd6e..c71046d983 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -52,7 +52,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
- try (Tensor<Integer> value = Tensor.create(shape, data)) {
+ try (Tensor<Integer> value = Tensor.create(shape.asArray(), data)) {
return createWithTensor(scope, value);
}
}
@@ -85,7 +85,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
- try (Tensor<Float> value = Tensor.create(shape, data)) {
+ try (Tensor<Float> value = Tensor.create(shape.asArray(), data)) {
return createWithTensor(scope, value);
}
}
@@ -118,7 +118,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
- try (Tensor<Double> value = Tensor.create(shape, data)) {
+ try (Tensor<Double> value = Tensor.create(shape.asArray(), data)) {
return createWithTensor(scope, value);
}
}
@@ -151,7 +151,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
- try (Tensor<Long> value = Tensor.create(shape, data)) {
+ try (Tensor<Long> value = Tensor.create(shape.asArray(), data)) {
return createWithTensor(scope, value);
}
}
@@ -226,7 +226,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* buffer
*/
public static <T> Constant<T> create(Scope scope, Class<T> type, Shape shape, ByteBuffer data) {
- try (Tensor<T> value = Tensor.create(type, shape, data)) {
+ try (Tensor<T> value = Tensor.create(type, shape.asArray(), data)) {
return createWithTensor(scope, value);
}
}
@@ -239,7 +239,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* provided. For example:
*
* <pre>{@code
- * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7
+ * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix
* }</pre>
*
* @param scope is a scope used to add the underlying operation.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
index 7dd35bb21f..5bba594e17 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -28,7 +28,7 @@ public class Zeros<T> implements Op, Operand<T> {
* @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
*/
public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
- int numElements = (int) shape.numElements();
+ int numElements = shape.numElements();
if (numElements < 0) {
throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index 177a0789de..63e191cd38 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -169,7 +169,7 @@ public class ConstantTest {
@Test
public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
- Shape shape = Shape.unknown();
+ Shape shape = Shape.scalar();
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
// followed by a varint encoded size, followed by the data.
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
index d32cc09ae3..ab3446b72b 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -15,7 +15,8 @@ limitations under the License.
package org.tensorflow.op.core;
-import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -33,97 +34,109 @@ public class ZerosTest {
@Test
public void createIntZeros() {
- Shape shape = Shape.make(2, 2);
- int[] expected = new int[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<Integer> op = Zeros.create(scope, Integer.class, Shape.make(2, 2));
Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
- int[] actual = new int[result.numElements()];
+ int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
}
}
@Test
public void createFloatZeros() {
- Shape shape = Shape.make(2, 2);
- float[] expected = new float[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<Float> op = Zeros.create(scope, Float.class, Shape.make(2, 2));
Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[] actual = new float[shape.numElements()];
+ float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual, EPSILON);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertEquals(0.0f, actual[i][j], EPSILON);
+ }
+ }
}
}
@Test
public void createDoubleZeros() {
- Shape shape = Shape.make(2, 2);
- double[] expected = new double[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<Double> op = Zeros.create(scope, Double.class, Shape.make(2, 2));
Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[] actual = new double[shape.numElements()];
+ double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual, EPSILON);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertEquals(0.0, actual[i][j], EPSILON);
+ }
+ }
}
}
@Test
public void createLongZeros() {
- Shape shape = Shape.make(2, 2);
- float[] expected = new float[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<Long> op = Zeros.create(scope, Long.class, Shape.make(2, 2));
Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- float[] actual = new float[shape.numElements()];
+ long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual, 0.0f);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertEquals(0L, actual[i][j]);
+ }
+ }
}
}
@Test
public void createBooleanZeros() {
- Shape shape = Shape.make(2, 2);
- boolean[] expected = new boolean[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<Boolean> op = Zeros.create(scope, Boolean.class, Shape.make(2, 2));
Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
- boolean[] actual = new boolean[shape.numElements()];
+ boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertFalse(actual[i][j]);
+ }
+ }
}
}
@Test
public void createUInt8Zeros() {
- Shape shape = Shape.make(2, 2);
- byte[] expected = new byte[shape.numElements()]; // all zeros
-
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
+ Shape shape = Shape.make(2, 2);
Zeros<UInt8> op = Zeros.create(scope, UInt8.class, Shape.make(2, 2));
Tensor<UInt8> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class);
- byte[] actual = new byte[shape.numElements()];
+ byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)];
result.copyTo(actual);
- assertArrayEquals(expected, actual);
+ for (int i = 0; i < shape.size(0); ++i) {
+ for (int j = 0; j < shape.size(1); ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
}
}