aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java82
1 files changed, 74 insertions, 8 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ca54214e06..177a0789de 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.op.core;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayOutputStream;
@@ -26,38 +27,65 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
+import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
@RunWith(JUnit4.class)
public class ConstantTest {
private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createInt() {
+ int value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Integer> op = Constant.create(scope, value);
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
+ assertEquals(value, result.intValue());
+ }
+ }
@Test
public void createIntBuffer() {
int[] ints = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput())
- .run().get(0).expect(Integer.class);
+ Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
int[] actual = new int[ints.length];
assertArrayEquals(ints, result.copyTo(actual));
}
}
@Test
+ public void createFloat() {
+ float value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Float> op = Constant.create(scope, value);
+ Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
+ assertEquals(value, result.floatValue(), 0.0f);
+ }
+ }
+
+ @Test
public void createFloatBuffer() {
float[] floats = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -70,9 +98,22 @@ public class ConstantTest {
}
@Test
+ public void createDouble() {
+ double value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Double> op = Constant.create(scope, value);
+ Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
+ assertEquals(value, result.doubleValue(), 0.0);
+ }
+ }
+
+ @Test
public void createDoubleBuffer() {
double[] doubles = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -85,9 +126,22 @@ public class ConstantTest {
}
@Test
+ public void createLong() {
+ long value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Long> op = Constant.create(scope, value);
+ Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
+ assertEquals(value, result.longValue());
+ }
+ }
+
+ @Test
public void createLongBuffer() {
long[] longs = {1, 2, 3, 4};
- long[] shape = {4};
+ Shape shape = Shape.make(4);
try (Graph g = new Graph();
Session sess = new Session(g)) {
@@ -100,10 +154,22 @@ public class ConstantTest {
}
@Test
- public void createStringBuffer() throws IOException {
+ public void createBoolean() {
+ boolean value = true;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Boolean> op = Constant.create(scope, value);
+ Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
+ assertEquals(value, result.booleanValue());
+ }
+ }
+ @Test
+ public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
- long[] shape = {};
+ Shape shape = Shape.unknown();
// byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
// followed by a varint encoded size, followed by the data.