aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
index 2ffc69c209..b75f79a421 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -108,17 +108,18 @@ public class GradientsTest {
}
@Test
- public void createGradientsWithScopeName() {
+ public void validateGradientsNames() {
try (Graph g = new Graph()) {
- Scope scope = new Scope(g);
+ Scope scope = new Scope(g).withSubScope("sub");
Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> y = TestUtil.square(g, "y", x);
- Scope gradScope = scope.withSubScope("grads").withSubScope("test");
- Gradients grads = Gradients.create(gradScope, y, Arrays.asList(x));
+ Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+ assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
- assertTrue(grads.dy(0).op().name().startsWith("grads/test/"));
+ Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+ assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
}
}
}