diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/SessionTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/SessionTest.java | 38 |
1 files changed, 7 insertions, 31 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6..7d5980bcde 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output<Integer> feed = g.operation("X").output(0); Output<Integer> fetch = g.operation("Y").output(0); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> - implements AutoCloseable { - AutoCloseableList(Collection<? extends E> c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating |