diff options
author | 2018-06-26 23:30:26 -0400 | |
---|---|---|
committer | 2018-06-27 21:48:00 -0400 | |
commit | 52e32a7b0ea35b52ec3a9ea5d522a08719f26068 (patch) | |
tree | ee9587483ded59a187c3ce5845ea872e59d2481a /tensorflow/java/src | |
parent | 9b7d92dbad4a18df0c34ff425a1e236f1dd75817 (diff) |
Second code review
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java (renamed from tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java) | 2 | ||||
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 55 | ||||
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/SessionTest.java | 38 | ||||
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/TestUtil.java | 25 |
4 files changed, 65 insertions, 55 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index 097b541501..f4671c8af9 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.op.training; +package org.tensorflow.op.core; import java.util.Arrays; import java.util.Iterator; diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index ac867f1e46..3ffc249185 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; -import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -135,7 +134,7 @@ public class GraphTest { @Test public void addGradientsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g)) { Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class); Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class); @@ -147,23 +146,27 @@ public class GraphTest { assertEquals(2, grads.length); assertEquals(DataType.FLOAT, grads[0].dataType()); assertEquals(DataType.FLOAT, grads[1].dataType()); - - List<Tensor<?>> outputs = s.runner() - .feed(x1, Tensors.create(3.0f)) - .feed(x2, Tensors.create(2.0f)) - .fetch(grads[0]) - .fetch(grads[1]) - .run(); + + try (Tensor<Float> c1 = Tensors.create(3.0f); + Tensor<Float> c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads[0]) + .fetch(grads[1]) + .run())) { - assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f); - assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f); + } } } @Test public void addGradientSumsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g)) { Output<Float> x = TestUtil.placeholder(g, "x", Float.class); Output<Float> y0 = TestUtil.square(g, "y0", x); @@ -171,19 +174,22 @@ public class GraphTest { Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null); - List<Tensor<?>> outputs = s.runner() - .feed(x, Tensors.create(3.0f)) - .fetch(grads[0]) - .run(); + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grads[0]) + .run() + .get(0)) { - assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(114.0f, output.floatValue(), 0.0f); + } } } @Test public void addGradientsWithInitialValuesToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g)) { Output<Float> x = TestUtil.placeholder(g, "x", Float.class); Output<Float> y = TestUtil.square(g, "y", x); @@ -191,12 +197,15 @@ public class GraphTest { Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx)); - List<Tensor<?>> outputs = s.runner() - .feed(x, Tensors.create(3.0f)) - .fetch(grads[0]) - .run(); + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grads[0]) + .run() + .get(0)) { - assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(108.0f, output.floatValue(), 0.0f); + } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6..7d5980bcde 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output<Integer> feed = g.operation("X").output(0); Output<Integer> fetch = g.operation("Y").output(0); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> - implements AutoCloseable { - AutoCloseableList(Collection<? extends E> c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 7feb296aed..4e84886416 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> + implements AutoCloseable { + AutoCloseableList(Collection<? extends E> c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static <T> Output<T> constant(Graph g, String name, Object value) { try (Tensor<?> t = Tensor.create(value)) { return g.opBuilder("Const", name) |