aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java113
1 files changed, 59 insertions, 54 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
index ab3446b72b..24339d92e6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -18,12 +18,13 @@ package org.tensorflow.op.core;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import java.util.List;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.tensorflow.Graph;
import org.tensorflow.Session;
-import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.op.Scope;
import org.tensorflow.types.UInt8;
@@ -37,14 +38,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Integer> op = Zeros.create(scope, Integer.class, Shape.make(2, 2));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class);
- int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
}
}
}
@@ -55,14 +56,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Float> op = Zeros.create(scope, Float.class, Shape.make(2, 2));
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0.0f, actual[i][j], EPSILON);
+ long[] shape = {2, 2};
+ Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0f, actual[i][j], EPSILON);
+ }
}
}
}
@@ -73,14 +74,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Double> op = Zeros.create(scope, Double.class, Shape.make(2, 2));
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0.0, actual[i][j], EPSILON);
+ long[] shape = {2, 2};
+ Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0, actual[i][j], EPSILON);
+ }
}
}
}
@@ -91,14 +92,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Long> op = Zeros.create(scope, Long.class, Shape.make(2, 2));
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0L, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0L, actual[i][j]);
+ }
}
}
}
@@ -109,14 +110,14 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<Boolean> op = Zeros.create(scope, Boolean.class, Shape.make(2, 2));
- Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class);
- boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertFalse(actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertFalse(actual[i][j]);
+ }
}
}
}
@@ -127,14 +128,15 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Shape shape = Shape.make(2, 2);
- Zeros<UInt8> op = Zeros.create(scope, UInt8.class, Shape.make(2, 2));
- Tensor<UInt8> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class);
- byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)];
- result.copyTo(actual);
- for (int i = 0; i < shape.size(0); ++i) {
- for (int j = 0; j < shape.size(1); ++j) {
- assertEquals(0, actual[i][j]);
+ long[] shape = {2, 2};
+ Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
+ result.copyTo(actual);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
}
}
}
@@ -145,16 +147,19 @@ public class ZerosTest {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Zeros.create(scope, String.class, Shape.make(2, 2));
+ long[] shape = {2, 2};
+ Zeros.create(scope, Constant.create(scope, shape), String.class);
}
}
-
- @Test(expected = IllegalArgumentException.class)
- public void cannotCreateZerosWithUnknownDimensions() {
+
+ @Test
+ public void operationsComposingZerosAreCorrectlyNamed() {
try (Graph g = new Graph();
Session sess = new Session(g)) {
Scope scope = new Scope(g);
- Zeros.create(scope, Float.class, Shape.make(2, -1));
+ long[] shape = {2, 2};
+ Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
+ List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
}
}
}