aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-19 22:25:03 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-08-02 00:34:32 -0400
commitbabd79185946130e86ccc3176b7071db5f274309 (patch)
tree3c7f96bc68bf29021644cd09c5297b9f9e8e71a7 /tensorflow/java
parent3359a5fdedb9988ed53879c85e63259d9cefc889 (diff)
Add zeros operator and unit tests
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java35
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java22
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java15
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java93
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java56
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java82
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java147
7 files changed, 383 insertions, 67 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index 7b92be6d38..ded09974a4 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -17,21 +17,22 @@ package org.tensorflow;
import java.util.HashMap;
import java.util.Map;
+
import org.tensorflow.types.UInt8;
/** Represents the type of elements in a {@link Tensor} as an enum. */
public enum DataType {
/** 32-bit single precision floating point. */
- FLOAT(1),
+ FLOAT(1, 4),
/** 64-bit double precision floating point. */
- DOUBLE(2),
+ DOUBLE(2, 8),
/** 32-bit signed integer. */
- INT32(3),
+ INT32(3, 4),
/** 8-bit unsigned integer. */
- UINT8(4),
+ UINT8(4, 1),
/**
* A sequence of bytes.
@@ -41,16 +42,36 @@ public enum DataType {
STRING(7),
/** 64-bit signed integer. */
- INT64(9),
+ INT64(9, 8),
/** Boolean. */
- BOOL(10);
+ BOOL(10, 1);
private final int value;
+
+ private final int sizeInBytes;
- // The integer value must match the corresponding TF_* value in the TensorFlow C API.
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ */
DataType(int value) {
+ this(value, -1);
+ }
+
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown
+ */
+ DataType(int value, int sizeInBytes) {
this.value = value;
+ this.sizeInBytes = sizeInBytes;
+ }
+
+ /**
+ * @return size of an element of this type, in bytes, or -1 if element size is variable
+ */
+ public int sizeInBytes() {
+ return sizeInBytes;
}
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index d533c3d480..1662a49cb7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -77,6 +77,28 @@ public final class Shape {
return shape[i];
}
+ /**
+ * The total number of elements found in a tensor of this shape.
+ *
+ * <p>If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned.
+ *
+ * @return the number of elements or -1 if size of some dimension are unknown
+ */
+ public int numElements() {
+ if (shape == null) {
+ return -1;
+ }
+ long total = 1;
+ for (int i = 0; i < shape.length; ++i) {
+ long size = size(i);
+ if (size < 0) {
+ return -1;
+ }
+ total *= size;
+ }
+ return total;
+ }
+
@Override
public int hashCode() {
return Arrays.hashCode(shape);
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 6e82efdf53..a307269ab5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -595,20 +595,11 @@ public final class Tensor<T> implements AutoCloseable {
}
private static int elemByteSize(DataType dataType) {
- switch (dataType) {
- case FLOAT:
- case INT32:
- return 4;
- case DOUBLE:
- case INT64:
- return 8;
- case BOOL:
- case UINT8:
- return 1;
- case STRING:
+ int size = dataType.sizeInBytes();
+ if (size < 0) {
throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
}
- throw new IllegalArgumentException("DataType " + dataType + " is not supported yet");
+ return size;
}
private static void throwExceptionIfNotByteOfByteArrays(Object array) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index bcf165346c..a3667dfd6e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -38,22 +38,21 @@ import org.tensorflow.op.annotation.Operator;
public final class Constant<T> extends PrimitiveOp implements Operand<T> {
/**
- * Create a constant from a Java object.
- *
- * <p>The argument {@code object} is first converted into a Tensor using {@link
- * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
- * provided. For example:
+ * Create a {@link DataType#INT32} constant with data from the given buffer.
*
- * <pre>{@code
- * Constant.create(scope, 7); // returns a constant scalar tensor 7
- * }</pre>
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
*
* @param scope is a scope used to add the underlying operation.
- * @param object a Java object representing the constant.
- * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @return an integer constant
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
- try (Tensor<T> value = Tensor.create(object, type)) {
+ public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
+ try (Tensor<Integer> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -63,6 +62,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return an integer constant
*/
public static Constant<Integer> create(Scope scope, int data) {
try (Tensor<Integer> value = Tensors.create(data)) {
@@ -71,7 +71,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#INT32} constant with data from the given buffer.
+ * Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -81,10 +81,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a float constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Integer> create(Scope scope, Shape shape, IntBuffer data) {
- try (Tensor<Integer> value = Tensor.create(shape, data)) {
+ public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
+ try (Tensor<Float> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -94,6 +95,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a float constant
*/
public static Constant<Float> create(Scope scope, float data) {
try (Tensor<Float> value = Tensors.create(data)) {
@@ -102,7 +104,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#FLOAT} constant with data from the given buffer.
+ * Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -112,10 +114,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a double constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Float> create(Scope scope, Shape shape, FloatBuffer data) {
- try (Tensor<Float> value = Tensor.create(shape, data)) {
+ public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
+ try (Tensor<Double> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -125,6 +128,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a double constant
*/
public static Constant<Double> create(Scope scope, double data) {
try (Tensor<Double> value = Tensors.create(data)) {
@@ -133,7 +137,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#DOUBLE} constant with data from the given buffer.
+ * Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
* its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
@@ -143,10 +147,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a long constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
- public static Constant<Double> create(Scope scope, Shape shape, DoubleBuffer data) {
- try (Tensor<Double> value = Tensor.create(shape, data)) {
+ public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
+ try (Tensor<Long> value = Tensor.create(shape, data)) {
return createWithTensor(scope, value);
}
}
@@ -156,6 +161,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a long constant
*/
public static Constant<Long> create(Scope scope, long data) {
try (Tensor<Long> value = Tensors.create(data)) {
@@ -164,29 +170,11 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
- * Create a {@link DataType#INT64} constant with data from the given buffer.
- *
- * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
- * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
- * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
- * method.
- *
- * @param scope is a scope used to add the underlying operation.
- * @param shape the tensor shape.
- * @param data a buffer containing the tensor data.
- * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
- */
- public static Constant<Long> create(Scope scope, Shape shape, LongBuffer data) {
- try (Tensor<Long> value = Tensor.create(shape, data)) {
- return createWithTensor(scope, value);
- }
- }
-
- /**
* Creates a constant containing a single {@code boolean} element.
*
* @param scope is a scope used to add the underlying operation.
* @param data The value to put into the new constant.
+ * @return a boolean constant
*/
public static Constant<Boolean> create(Scope scope, boolean data) {
try (Tensor<Boolean> value = Tensors.create(data)) {
@@ -199,6 +187,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
*
* @param scope is a scope used to add the underlying operation.
* @param data The string to put into the new constant.
+ * @return a string constant
*/
public static Constant<String> create(Scope scope, String data) {
try (Tensor<String> value = Tensors.create(data)) {
@@ -212,6 +201,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param charset The encoding from String to bytes.
* @param data The string to put into the new constant.
+ * @return a string constant
*/
public static Constant<String> create(Scope scope, String data, Charset charset) {
try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
@@ -231,6 +221,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param type the tensor datatype.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a constant of type `type`
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
@@ -240,6 +231,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
}
+ /**
+ * Create a constant from a Java object.
+ *
+ * <p>The argument {@code object} is first converted into a Tensor using {@link
+ * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
+ * provided. For example:
+ *
+ * <pre>{@code
+ * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7
+ * }</pre>
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param object a Java object representing the constant.
+ * @return a constant of type `type`
+ * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ */
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
return new Constant<T>(
scope
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
new file mode 100644
index 0000000000..7dd35bb21f
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -0,0 +1,56 @@
+package org.tensorflow.op.core;
+
+import java.nio.ByteBuffer;
+
+import org.tensorflow.DataType;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.Shape;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * An operator creating a constant initialized with zeros w.r.t its type and shape.
+ *
+ * @param <T> constant type
+ */
+@Operator
+public class Zeros<T> implements Op, Operand<T> {
+
+ /**
+ * Factory method for this operator
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param type the tensor datatype.
+ * @param shape the tensor shape.
+ * @return a constant initialized with zeros
+ * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
+ */
+ public static <T> Zeros<T> create(Scope scope, Class<T> type, Shape shape) {
+ int numElements = (int) shape.numElements();
+ if (numElements < 0) {
+ throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants");
+ }
+ int sizeInBytes = DataType.fromClass(type).sizeInBytes();
+ if (sizeInBytes < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros");
+ }
+ return new Zeros<T>(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes)));
+ }
+
+ @Override
+ public Output<T> asOutput() {
+ return constant.asOutput();
+ }
+
+ public Constant<T> constant() {
+ return constant;
+ }
+
+ private final Constant<T> constant;
+
+ private Zeros(Constant<T> constant) {
+ this.constant = constant;
+ }
+}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ca54214e06..177a0789de 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.op.core;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayOutputStream;
@@ -26,38 +27,65 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
+import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
@RunWith(JUnit4.class)
public class ConstantTest {
private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createInt() {
+ int value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Integer> op = Constant.create(scope, value);
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
+ assertEquals(value, result.intValue());
+ }
+ }
@Test
public void createIntBuffer() {
int[] ints = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput())
- .run().get(0).expect(Integer.class);
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
int[] actual = new int[ints.length];
assertArrayEquals(ints, result.copyTo(actual));
}
}
@Test
+ public void createFloat() {
+ float value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Float> op = Constant.create(scope, value);
+ Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
+ assertEquals(value, result.floatValue(), 0.0f);
+ }
+ }
+
+ @Test
public void createFloatBuffer() {
float[] floats = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -70,9 +98,22 @@ public class ConstantTest {
}
@Test
+ public void createDouble() {
+ double value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Double> op = Constant.create(scope, value);
+ Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
+ assertEquals(value, result.doubleValue(), 0.0);
+ }
+ }
+
+ @Test
public void createDoubleBuffer() {
double[] doubles = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -85,9 +126,22 @@ public class ConstantTest {
}
@Test
+ public void createLong() {
+ long value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Long> op = Constant.create(scope, value);
+ Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
+ assertEquals(value, result.longValue());
+ }
+ }
+
+ @Test
public void createLongBuffer() {
long[] longs = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -100,10 +154,22 @@ public class ConstantTest {
}
@Test
- public void createStringBuffer() throws IOException {
+ public void createBoolean() {
+ boolean value = true;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Boolean> op = Constant.create(scope, value);
+ Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
+ assertEquals(value, result.booleanValue());
+ }
+ }
+ @Test
+ public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
- long[] shape = {};
+ Shape shape = Shape.unknown();
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
// followed by a varint encoded size, followed by the data.
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
new file mode 100644
index 0000000000..d32cc09ae3
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -0,0 +1,147 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertArrayEquals;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Shape;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.Scope;
+import org.tensorflow.types.UInt8;
+
+@RunWith(JUnit4.class)
+public class ZerosTest {
+ private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createIntZeros() {
+ Shape shape = Shape.make(2, 2);
+ int[] expected = new int[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<Integer> op = Zeros.create(scope, Integer.class, Shape.make(2, 2));
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
+ int[] actual = new int[result.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual);
+ }
+ }
+
+ @Test
+ public void createFloatZeros() {
+ Shape shape = Shape.make(2, 2);
+ float[] expected = new float[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<Float> op = Zeros.create(scope, Float.class, Shape.make(2, 2));
+ Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
+ float[] actual = new float[shape.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual, EPSILON);
+ }
+ }
+
+ @Test
+ public void createDoubleZeros() {
+ Shape shape = Shape.make(2, 2);
+ double[] expected = new double[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<Double> op = Zeros.create(scope, Double.class, Shape.make(2, 2));
+ Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
+ double[] actual = new double[shape.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual, EPSILON);
+ }
+ }
+
+ @Test
+ public void createLongZeros() {
+ Shape shape = Shape.make(2, 2);
+ float[] expected = new float[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<Long> op = Zeros.create(scope, Long.class, Shape.make(2, 2));
+ Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
+ float[] actual = new float[shape.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual, 0.0f);
+ }
+ }
+
+ @Test
+ public void createBooleanZeros() {
+ Shape shape = Shape.make(2, 2);
+ boolean[] expected = new boolean[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<Boolean> op = Zeros.create(scope, Boolean.class, Shape.make(2, 2));
+ Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
+ boolean[] actual = new boolean[shape.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual);
+ }
+ }
+
+ @Test
+ public void createUInt8Zeros() {
+ Shape shape = Shape.make(2, 2);
+ byte[] expected = new byte[shape.numElements()]; // all zeros
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros<UInt8> op = Zeros.create(scope, UInt8.class, Shape.make(2, 2));
+ Tensor<UInt8> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class);
+ byte[] actual = new byte[shape.numElements()];
+ result.copyTo(actual);
+ assertArrayEquals(expected, actual);
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void cannotCreateStringZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros.create(scope, String.class, Shape.make(2, 2));
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void cannotCreateZerosWithUnknownDimensions() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Zeros.create(scope, Float.class, Shape.make(2, -1));
+ }
+ }
+}