aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/GraphTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java103
1 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c540299bdc..c2e52c22c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -129,4 +130,106 @@ public class GraphTest {
// expected exception.
}
}
+
+ @Test
+ public void addGradientsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x1);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+ Output<Float> y2 = TestUtil.addN(g, y0, x2);
+
+ Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
+ assertNotNull(grads0);
+ assertEquals(1, grads0.length);
+ assertEquals(DataType.FLOAT, grads0[0].dataType());
+
+ Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
+ assertNotNull(grads1);
+ assertEquals(2, grads1.length);
+ assertEquals(DataType.FLOAT, grads1[0].dataType());
+ assertEquals(DataType.FLOAT, grads1[1].dataType());
+
+ try (Tensor<Float> c1 = Tensors.create(3.0f);
+ Tensor<Float> c2 = Tensors.create(2.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ s.runner()
+ .feed(x1, c1)
+ .feed(x2, c2)
+ .fetch(grads0[0])
+ .fetch(grads1[0])
+ .fetch(grads1[1])
+ .run())) {
+
+ assertEquals(3, outputs.size());
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientSumsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ assertNotNull(grad);
+ assertEquals(1, grad.length);
+ assertEquals(DataType.FLOAT, grad[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(114.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientsWithInitialValuesToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
+ assertNotNull(grad0);
+ assertEquals(1, grad0.length);
+ assertEquals(DataType.FLOAT, grad0[0].dataType());
+
+ Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ assertNotNull(grad1);
+ assertEquals(1, grad1.length);
+ assertEquals(DataType.FLOAT, grad1[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad1[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(108.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ private static Output<?>[] toArray(Output<?>... outputs) {
+ return outputs;
+ }
}