diff options
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/GraphTest.java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c2e52c22c6..c02336aebe 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -181,7 +181,7 @@ public class GraphTest { Output<Float> y0 = TestUtil.square(g, "y0", x); Output<Float> y1 = TestUtil.square(g, "y1", y0); - Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); assertNotNull(grad); assertEquals(1, grad.length); assertEquals(DataType.FLOAT, grad[0].dataType()); @@ -212,7 +212,7 @@ public class GraphTest { assertEquals(1, grad0.length); assertEquals(DataType.FLOAT, grad0[0].dataType()); - Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); assertNotNull(grad1); assertEquals(1, grad1.length); assertEquals(DataType.FLOAT, grad1[0].dataType()); @@ -229,6 +229,21 @@ public class GraphTest { } } + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + + Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad0[0].op().name().startsWith("gradients/")); + + Output<?>[] grad1 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad1[0].op().name().startsWith("more_gradients/")); + } + } + private static Output<?>[] toArray(Output<?>... outputs) { return outputs; } |