From babd79185946130e86ccc3176b7071db5f274309 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Thu, 19 Jul 2018 22:25:03 -0400 Subject: Add zeros operator and unit tests --- .../src/main/java/org/tensorflow/DataType.java | 35 ++++- .../java/src/main/java/org/tensorflow/Shape.java | 22 +++ .../java/src/main/java/org/tensorflow/Tensor.java | 15 +-- .../main/java/org/tensorflow/op/core/Constant.java | 93 +++++++------ .../main/java/org/tensorflow/op/core/Zeros.java | 56 ++++++++ .../java/org/tensorflow/op/core/ConstantTest.java | 82 ++++++++++-- .../java/org/tensorflow/op/core/ZerosTest.java | 147 +++++++++++++++++++++ 7 files changed, 383 insertions(+), 67 deletions(-) create mode 100644 tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java create mode 100644 tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java (limited to 'tensorflow') diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 7b92be6d38..ded09974a4 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -17,21 +17,22 @@ package org.tensorflow; import java.util.HashMap; import java.util.Map; + import org.tensorflow.types.UInt8; /** Represents the type of elements in a {@link Tensor} as an enum. */ public enum DataType { /** 32-bit single precision floating point. */ - FLOAT(1), + FLOAT(1, 4), /** 64-bit double precision floating point. */ - DOUBLE(2), + DOUBLE(2, 8), /** 32-bit signed integer. */ - INT32(3), + INT32(3, 4), /** 8-bit unsigned integer. */ - UINT8(4), + UINT8(4, 1), /** * A sequence of bytes. @@ -41,16 +42,36 @@ public enum DataType { STRING(7), /** 64-bit signed integer. */ - INT64(9), + INT64(9, 8), /** Boolean. */ - BOOL(10); + BOOL(10, 1); private final int value; + + private final int sizeInBytes; - // The integer value must match the corresponding TF_* value in the TensorFlow C API. + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + */ DataType(int value) { + this(value, -1); + } + + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param sizeInBytes size of an element of this type, in bytes, -1 if unknown + */ + DataType(int value, int sizeInBytes) { this.value = value; + this.sizeInBytes = sizeInBytes; + } + + /** + * @return size of an element of this type, in bytes, or -1 if element size is variable + */ + public int sizeInBytes() { + return sizeInBytes; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index d533c3d480..1662a49cb7 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -77,6 +77,28 @@ public final class Shape { return shape[i]; } + /** + * The total number of elements found in a tensor of this shape. + * + *

If the size of some dimensions is unknown, the total number of elements cannot be calculated and -1 is returned. + * + * @return the number of elements or -1 if size of some dimension are unknown + */ + public int numElements() { + if (shape == null) { + return -1; + } + long total = 1; + for (int i = 0; i < shape.length; ++i) { + long size = size(i); + if (size < 0) { + return -1; + } + total *= size; + } + return total; + } + @Override public int hashCode() { return Arrays.hashCode(shape); diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 6e82efdf53..a307269ab5 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -595,20 +595,11 @@ public final class Tensor implements AutoCloseable { } private static int elemByteSize(DataType dataType) { - switch (dataType) { - case FLOAT: - case INT32: - return 4; - case DOUBLE: - case INT64: - return 8; - case BOOL: - case UINT8: - return 1; - case STRING: + int size = dataType.sizeInBytes(); + if (size < 0) { throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } - throw new IllegalArgumentException("DataType " + dataType + " is not supported yet"); + return size; } private static void throwExceptionIfNotByteOfByteArrays(Object array) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index bcf165346c..a3667dfd6e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -38,22 +38,21 @@ import org.tensorflow.op.annotation.Operator; public final class Constant extends PrimitiveOp implements Operand { /** - * Create a constant from a Java object. - * - *

The argument {@code object} is first converted into a Tensor using {@link - * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be - * provided. For example: + * Create a {@link DataType#INT32} constant with data from the given buffer. * - *

{@code
-   * Constant.create(scope, 7); // returns a constant scalar tensor 7
-   * }
+ *

Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. * * @param scope is a scope used to add the underlying operation. - * @param object a Java object representing the constant. - * @see org.tensorflow.Tensor#create(Object) Tensor.create + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @return an integer constant + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Object object, Class type) { - try (Tensor value = Tensor.create(object, type)) { + public static Constant create(Scope scope, Shape shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -63,6 +62,7 @@ public final class Constant extends PrimitiveOp implements Operand { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return an integer constant */ public static Constant create(Scope scope, int data) { try (Tensor value = Tensors.create(data)) { @@ -71,7 +71,7 @@ public final class Constant extends PrimitiveOp implements Operand { } /** - * Create a {@link DataType#INT32} constant with data from the given buffer. + * Create a {@link DataType#FLOAT} constant with data from the given buffer. * *

Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -81,10 +81,11 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, IntBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + public static Constant create(Scope scope, Shape shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -94,6 +95,7 @@ public final class Constant extends PrimitiveOp implements Operand { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a float constant */ public static Constant create(Scope scope, float data) { try (Tensor value = Tensors.create(data)) { @@ -102,7 +104,7 @@ public final class Constant extends PrimitiveOp implements Operand { } /** - * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * *

Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -112,10 +114,11 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, FloatBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + public static Constant create(Scope scope, Shape shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -125,6 +128,7 @@ public final class Constant extends PrimitiveOp implements Operand { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a double constant */ public static Constant create(Scope scope, double data) { try (Tensor value = Tensors.create(data)) { @@ -133,7 +137,7 @@ public final class Constant extends PrimitiveOp implements Operand { } /** - * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * Create a {@link DataType#INT64} constant with data from the given buffer. * *

Creates a constant with the given shape by copying elements from the buffer (starting from * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents @@ -143,10 +147,11 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public static Constant create(Scope scope, Shape shape, DoubleBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { + public static Constant create(Scope scope, Shape shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { return createWithTensor(scope, value); } } @@ -156,6 +161,7 @@ public final class Constant extends PrimitiveOp implements Operand { * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a long constant */ public static Constant create(Scope scope, long data) { try (Tensor value = Tensors.create(data)) { @@ -163,30 +169,12 @@ public final class Constant extends PrimitiveOp implements Operand { } } - /** - * Create a {@link DataType#INT64} constant with data from the given buffer. - * - *

Creates a constant with the given shape by copying elements from the buffer (starting from - * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents - * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this - * method. - * - * @param scope is a scope used to add the underlying operation. - * @param shape the tensor shape. - * @param data a buffer containing the tensor data. - * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer - */ - public static Constant create(Scope scope, Shape shape, LongBuffer data) { - try (Tensor value = Tensor.create(shape, data)) { - return createWithTensor(scope, value); - } - } - /** * Creates a constant containing a single {@code boolean} element. * * @param scope is a scope used to add the underlying operation. * @param data The value to put into the new constant. + * @return a boolean constant */ public static Constant create(Scope scope, boolean data) { try (Tensor value = Tensors.create(data)) { @@ -199,6 +187,7 @@ public final class Constant extends PrimitiveOp implements Operand { * * @param scope is a scope used to add the underlying operation. * @param data The string to put into the new constant. + * @return a string constant */ public static Constant create(Scope scope, String data) { try (Tensor value = Tensors.create(data)) { @@ -212,6 +201,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param scope is a scope used to add the underlying operation. * @param charset The encoding from String to bytes. * @param data The string to put into the new constant. + * @return a string constant */ public static Constant create(Scope scope, String data, Charset charset) { try (Tensor value = Tensor.create(data.getBytes(charset), String.class)) { @@ -231,6 +221,7 @@ public final class Constant extends PrimitiveOp implements Operand { * @param type the tensor datatype. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ @@ -240,6 +231,28 @@ public final class Constant extends PrimitiveOp implements Operand { } } + /** + * Create a constant from a Java object. + * + *

The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + *

{@code
+   * Constant.create(scope, 7, Integer.class); // returns a constant scalar tensor 7
+   * }
+ * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @return a constant of type `type` + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object, Class type) { + try (Tensor value = Tensor.create(object, type)) { + return createWithTensor(scope, value); + } + } + private static Constant createWithTensor(Scope scope, Tensor value) { return new Constant( scope diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java new file mode 100644 index 0000000000..7dd35bb21f --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -0,0 +1,56 @@ +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * An operator creating a constant initialized with zeros w.r.t its type and shape. + * + * @param constant type + */ +@Operator +public class Zeros implements Op, Operand { + + /** + * Factory method for this operator + * + * @param scope is a scope used to add the underlying operation. + * @param type the tensor datatype. + * @param shape the tensor shape. + * @return a constant initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. + */ + public static Zeros create(Scope scope, Class type, Shape shape) { + int numElements = (int) shape.numElements(); + if (numElements < 0) { + throw new IllegalArgumentException("Only shapes with known dimension sizes can be used with zeroed constants"); + } + int sizeInBytes = DataType.fromClass(type).sizeInBytes(); + if (sizeInBytes < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " constants cannot be initialized with zeros"); + } + return new Zeros(Constant.create(scope, type, shape, ByteBuffer.allocate(numElements * sizeInBytes))); + } + + @Override + public Output asOutput() { + return constant.asOutput(); + } + + public Constant constant() { + return constant; + } + + private final Constant constant; + + private Zeros(Constant constant) { + this.constant = constant; + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index ca54214e06..177a0789de 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.op.core; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; @@ -26,38 +27,65 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.tensorflow.op.Scope; @RunWith(JUnit4.class) public class ConstantTest { private static final float EPSILON = 1e-7f; + + @Test + public void createInt() { + int value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); + assertEquals(value, result.intValue()); + } + } @Test public void createIntBuffer() { int[] ints = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor result = sess.runner().fetch(op.asOutput()) - .run().get(0).expect(Integer.class); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); int[] actual = new int[ints.length]; assertArrayEquals(ints, result.copyTo(actual)); } } + @Test + public void createFloat() { + float value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); + assertEquals(value, result.floatValue(), 0.0f); + } + } + @Test public void createFloatBuffer() { float[] floats = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -69,10 +97,23 @@ public class ConstantTest { } } + @Test + public void createDouble() { + double value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); + assertEquals(value, result.doubleValue(), 0.0); + } + } + @Test public void createDoubleBuffer() { double[] doubles = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -84,10 +125,23 @@ public class ConstantTest { } } + @Test + public void createLong() { + long value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); + assertEquals(value, result.longValue()); + } + } + @Test public void createLongBuffer() { long[] longs = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -100,10 +154,22 @@ public class ConstantTest { } @Test - public void createStringBuffer() throws IOException { + public void createBoolean() { + boolean value = true; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, value); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); + assertEquals(value, result.booleanValue()); + } + } + @Test + public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; - long[] shape = {}; + Shape shape = Shape.unknown(); // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, // followed by a varint encoded size, followed by the data. diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java new file mode 100644 index 0000000000..d32cc09ae3 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -0,0 +1,147 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertArrayEquals; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Shape; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; +import org.tensorflow.types.UInt8; + +@RunWith(JUnit4.class) +public class ZerosTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntZeros() { + Shape shape = Shape.make(2, 2); + int[] expected = new int[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, Integer.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); + int[] actual = new int[result.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual); + } + } + + @Test + public void createFloatZeros() { + Shape shape = Shape.make(2, 2); + float[] expected = new float[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, Float.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); + float[] actual = new float[shape.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual, EPSILON); + } + } + + @Test + public void createDoubleZeros() { + Shape shape = Shape.make(2, 2); + double[] expected = new double[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, Double.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); + double[] actual = new double[shape.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual, EPSILON); + } + } + + @Test + public void createLongZeros() { + Shape shape = Shape.make(2, 2); + float[] expected = new float[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, Long.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); + float[] actual = new float[shape.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual, 0.0f); + } + } + + @Test + public void createBooleanZeros() { + Shape shape = Shape.make(2, 2); + boolean[] expected = new boolean[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, Boolean.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); + boolean[] actual = new boolean[shape.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual); + } + } + + @Test + public void createUInt8Zeros() { + Shape shape = Shape.make(2, 2); + byte[] expected = new byte[shape.numElements()]; // all zeros + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros op = Zeros.create(scope, UInt8.class, Shape.make(2, 2)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class); + byte[] actual = new byte[shape.numElements()]; + result.copyTo(actual); + assertArrayEquals(expected, actual); + } + } + + @Test(expected = IllegalArgumentException.class) + public void cannotCreateStringZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros.create(scope, String.class, Shape.make(2, 2)); + } + } + + @Test(expected = IllegalArgumentException.class) + public void cannotCreateZerosWithUnknownDimensions() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Zeros.create(scope, Float.class, Shape.make(2, -1)); + } + } +} -- cgit v1.2.3