diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/GraphTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c2e52c22c6..7c05c1deaf 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; - import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -180,8 +179,8 @@ public class GraphTest { Output<Float> x = TestUtil.placeholder(g, "x", Float.class); Output<Float> y0 = TestUtil.square(g, "y0", x); Output<Float> y1 = TestUtil.square(g, "y1", y0); - - Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + + Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); assertNotNull(grad); assertEquals(1, grad.length); assertEquals(DataType.FLOAT, grad[0].dataType()); @@ -212,7 +211,7 @@ public class GraphTest { assertEquals(1, grad0.length); assertEquals(DataType.FLOAT, grad0[0].dataType()); - Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); assertNotNull(grad1); assertEquals(1, grad1.length); assertEquals(DataType.FLOAT, grad1[0].dataType()); @@ -228,6 +227,33 @@ public class GraphTest { } } } + + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + + Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad0[0].op().name().startsWith("gradients/")); + + Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad1[0].op().name().startsWith("gradients_1/")); + + Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad2[0].op().name().startsWith("more_gradients/")); + + Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad3[0].op().name().startsWith("even_more_gradients/")); + + try { + g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + } catch (IllegalArgumentException e) { + // expected exception + } + } + } private static Output<?>[] toArray(Output<?>... outputs) { return outputs; |