diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java | 113 |
1 files changed, 59 insertions, 54 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java index ab3446b72b..24339d92e6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -18,12 +18,13 @@ package org.tensorflow.op.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import java.util.List; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.tensorflow.op.Scope; import org.tensorflow.types.UInt8; @@ -37,14 +38,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<Integer> op = Zeros.create(scope, Integer.class, Shape.make(2, 2)); - Tensor<Integer> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Integer.class); - int[][] actual = new int[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0, actual[i][j]); + long[] shape = {2, 2}; + Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } } } } @@ -55,14 +56,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<Float> op = Zeros.create(scope, Float.class, Shape.make(2, 2)); - Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[][] actual = new float[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0.0f, actual[i][j], EPSILON); + long[] shape = {2, 2}; + Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } } } } @@ -73,14 +74,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<Double> op = Zeros.create(scope, Double.class, Shape.make(2, 2)); - Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[][] actual = new double[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0.0, actual[i][j], EPSILON); + long[] shape = {2, 2}; + Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } } } } @@ -91,14 +92,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<Long> op = Zeros.create(scope, Long.class, Shape.make(2, 2)); - Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[][] actual = new long[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0L, actual[i][j]); + long[] shape = {2, 2}; + Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0L, actual[i][j]); + } } } } @@ -109,14 +110,14 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<Boolean> op = Zeros.create(scope, Boolean.class, Shape.make(2, 2)); - Tensor<Boolean> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Boolean.class); - boolean[][] actual = new boolean[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertFalse(actual[i][j]); + long[] shape = {2, 2}; + Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertFalse(actual[i][j]); + } } } } @@ -127,14 +128,15 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Shape shape = Shape.make(2, 2); - Zeros<UInt8> op = Zeros.create(scope, UInt8.class, Shape.make(2, 2)); - Tensor<UInt8> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(UInt8.class); - byte[][] actual = new byte[(int)shape.size(0)][(int)shape.size(1)]; - result.copyTo(actual); - for (int i = 0; i < shape.size(0); ++i) { - for (int j = 0; j < shape.size(1); ++j) { - assertEquals(0, actual[i][j]); + long[] shape = {2, 2}; + Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); + result.copyTo(actual); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } } } } @@ -145,16 +147,19 @@ public class ZerosTest { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Zeros.create(scope, String.class, Shape.make(2, 2)); + long[] shape = {2, 2}; + Zeros.create(scope, Constant.create(scope, shape), String.class); } } - - @Test(expected = IllegalArgumentException.class) - public void cannotCreateZerosWithUnknownDimensions() { + + @Test + public void operationsComposingZerosAreCorrectlyNamed() { try (Graph g = new Graph(); Session sess = new Session(g)) { Scope scope = new Scope(g); - Zeros.create(scope, Float.class, Shape.make(2, -1)); + long[] shape = {2, 2}; + Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); + List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } } |