diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/GraphTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdc..c2e52c22c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -129,4 +130,106 @@ public class GraphTest { // expected exception. } } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x1); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + Output<Float> y2 = TestUtil.addN(g, y0, x2); + + Output<?>[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor<Float> c1 = Tensors.create(3.0f); + Tensor<Float> c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Output<?>[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + private static Output<?>[] toArray(Output<?>... outputs) { + return outputs; + } } |