diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java | 82 |
1 files changed, 74 insertions, 8 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index ca54214e06..177a0789de 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.op.core; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; @@ -26,38 +27,65 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.tensorflow.op.Scope; @RunWith(JUnit4.class) public class ConstantTest { private static final float EPSILON = 1e-7f; + + @Test + public void createInt() { + int value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Integer> op = Constant.create(scope, value); + Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); + assertEquals(value, result.intValue()); + } + } @Test public void createIntBuffer() { int[] ints = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor<Integer> result = sess.runner().fetch(op.asOutput()) - .run().get(0).expect(Integer.class); + Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); int[] actual = new int[ints.length]; assertArrayEquals(ints, result.copyTo(actual)); } } @Test + public void createFloat() { + float value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Float> op = Constant.create(scope, value); + Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); + assertEquals(value, result.floatValue(), 0.0f); + } + } + + @Test public void createFloatBuffer() { float[] floats = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -70,9 +98,22 @@ public class ConstantTest { } @Test + public void createDouble() { + double value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Double> op = Constant.create(scope, value); + Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); + assertEquals(value, result.doubleValue(), 0.0); + } + } + + @Test public void createDoubleBuffer() { double[] doubles = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -85,9 +126,22 @@ public class ConstantTest { } @Test + public void createLong() { + long value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Long> op = Constant.create(scope, value); + Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); + assertEquals(value, result.longValue()); + } + } + + @Test public void createLongBuffer() { long[] longs = {1, 2, 3, 4}; - long[] shape = {4}; + Shape shape = Shape.make(4); try (Graph g = new Graph(); Session sess = new Session(g)) { @@ -100,10 +154,22 @@ public class ConstantTest { } @Test - public void createStringBuffer() throws IOException { + public void createBoolean() { + boolean value = true; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Boolean> op = Constant.create(scope, value); + Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); + assertEquals(value, result.booleanValue()); + } + } + @Test + public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; - long[] shape = {}; + Shape shape = Shape.unknown(); // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, // followed by a varint encoded size, followed by the data. |