aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-26 23:30:26 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-06-27 21:48:00 -0400
commit52e32a7b0ea35b52ec3a9ea5d522a08719f26068 (patch)
treeee9587483ded59a187c3ce5845ea872e59d2481a /tensorflow/java/src
parent9b7d92dbad4a18df0c34ff425a1e236f1dd75817 (diff)
Second code review
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java (renamed from tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java)2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java55
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SessionTest.java38
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java25
4 files changed, 65 insertions, 55 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index 097b541501..f4671c8af9 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-package org.tensorflow.op.training;
+package org.tensorflow.op.core;
import java.util.Arrays;
import java.util.Iterator;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index ac867f1e46..3ffc249185 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -135,7 +134,7 @@ public class GraphTest {
@Test
public void addGradientsToGraph() {
try (Graph g = new Graph();
- Session s = new Session(g)) {
+ Session s = new Session(g)) {
Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
@@ -147,23 +146,27 @@ public class GraphTest {
assertEquals(2, grads.length);
assertEquals(DataType.FLOAT, grads[0].dataType());
assertEquals(DataType.FLOAT, grads[1].dataType());
-
- List<Tensor<?>> outputs = s.runner()
- .feed(x1, Tensors.create(3.0f))
- .feed(x2, Tensors.create(2.0f))
- .fetch(grads[0])
- .fetch(grads[1])
- .run();
+
+ try (Tensor<Float> c1 = Tensors.create(3.0f);
+ Tensor<Float> c2 = Tensors.create(2.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ s.runner()
+ .feed(x1, c1)
+ .feed(x2, c2)
+ .fetch(grads[0])
+ .fetch(grads[1])
+ .run())) {
- assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f);
- assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f);
+ }
}
}
@Test
public void addGradientSumsToGraph() {
try (Graph g = new Graph();
- Session s = new Session(g)) {
+ Session s = new Session(g)) {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
@@ -171,19 +174,22 @@ public class GraphTest {
Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);
- List<Tensor<?>> outputs = s.runner()
- .feed(x, Tensors.create(3.0f))
- .fetch(grads[0])
- .run();
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grads[0])
+ .run()
+ .get(0)) {
- assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(114.0f, output.floatValue(), 0.0f);
+ }
}
}
@Test
public void addGradientsWithInitialValuesToGraph() {
try (Graph g = new Graph();
- Session s = new Session(g)) {
+ Session s = new Session(g)) {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y = TestUtil.square(g, "y", x);
@@ -191,12 +197,15 @@ public class GraphTest {
Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));
- List<Tensor<?>> outputs = s.runner()
- .feed(x, Tensors.create(3.0f))
- .fetch(grads[0])
- .run();
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grads[0])
+ .run()
+ .get(0)) {
- assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(108.0f, output.floatValue(), 0.0f);
+ }
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index e8cc76c2a6..7d5980bcde 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import java.util.ArrayList;
-import java.util.Collection;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -36,8 +34,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -53,8 +51,8 @@ public class SessionTest {
Output<Integer> feed = g.operation("X").output(0);
Output<Integer> fetch = g.operation("Y").output(0);
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -112,7 +110,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -135,8 +133,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
@@ -164,28 +162,6 @@ public class SessionTest {
Session s = new Session(g, singleThreadConfigProto())) {}
}
- private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
- implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
- super(c);
- }
-
- @Override
- public void close() {
- Exception toThrow = null;
- for (AutoCloseable c : this) {
- try {
- c.close();
- } catch (Exception e) {
- toThrow = e;
- }
- }
- if (toThrow != null) {
- throw new RuntimeException(toThrow);
- }
- }
- }
-
private static byte[] fullTraceRunOptions() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 7feb296aed..4e84886416 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -16,9 +16,34 @@ limitations under the License.
package org.tensorflow;
import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Collection;
/** Static utility functions. */
public class TestUtil {
+
+ public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
+ implements AutoCloseable {
+ AutoCloseableList(Collection<? extends E> c) {
+ super(c);
+ }
+
+ @Override
+ public void close() {
+ Exception toThrow = null;
+ for (AutoCloseable c : this) {
+ try {
+ c.close();
+ } catch (Exception e) {
+ toThrow = e;
+ }
+ }
+ if (toThrow != null) {
+ throw new RuntimeException(toThrow);
+ }
+ }
+ }
+
public static <T> Output<T> constant(Graph g, String name, Object value) {
try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)