From a278365e8848f5fcbccb42f95a3c523367c1602f Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Mon, 16 Jul 2018 00:49:06 -0400 Subject: Enforce uniqueness of custom prefixes for gradients --- .../java/src/main/java/org/tensorflow/Graph.java | 8 ++++++- .../src/main/java/org/tensorflow/op/NameScope.java | 4 ---- .../src/main/java/org/tensorflow/op/Scope.java | 26 ++-------------------- .../java/org/tensorflow/op/core/Gradients.java | 6 ++++- .../src/test/java/org/tensorflow/GraphTest.java | 16 +++++++++++-- .../src/test/java/org/tensorflow/op/ScopeTest.java | 15 ------------- .../java/org/tensorflow/op/core/GradientsTest.java | 11 ++++----- 7 files changed, 34 insertions(+), 52 deletions(-) (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 353092701b..32853c1367 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -152,8 +152,14 @@ public final class Graph implements AutoCloseable { *

* If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all * shapes in {@code y}. + *

+ * {@code prefix} is used as the name prefix applied to all nodes added to the graph to compute gradients. It must + * be unique within the provided graph or the operation will fail. + *

+ * If {@code prefix} is null, then the nodes will be added to under the default prefix, which is "gradients" for the + * first invocation, then "gradients_1", "gradients_2", etc. for any subsequent calls to the same graph. * - * @param prefix string prefix applied to names of nodes added to the graph to compute gradients. + * @param prefix unique string prefix applied before the names of nodes added to the graph to compute gradients. * If null, defaults to "gradients". * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java index 95a2a2f9f5..2e84cac1ac 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java @@ -56,10 +56,6 @@ final class NameScope { String actualName = (opName != null) ? opName : name; return fullyQualify(makeUnique(actualName)); } - - String opPrefix() { - return opPrefix; - } /** * Create a new, root-level namescope. diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index cf0b3d98c1..563ea66ef1 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -135,17 +135,8 @@ public final class Scope { * } * *

Note: if you provide a composite operator building class (i.e, a class that adds a - * set of related operations to the graph by calling other operator building code) you should also - * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a - * meaningful name. - * - *

{@code
-   * public static Stddev create(Scope scope, ...) {
-   *   // group sub-operations under a common name
-   *   Scope group = scope.withSubScope("stddev");
-   *   ... Sqrt.create(group, Mean.create(group, ...))
-   * }
-   * }
+ * set of related operations to the graph by calling other operator building code), the provided name + * will act as a subscope to all underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. @@ -154,19 +145,6 @@ public final class Scope { public String makeOpName(String defaultName) { return nameScope.makeOpName(defaultName); } - - /** - * The name prefix of this scope. - *

- * This value is the combination of the name of this scope and all of its parents, seperated by a '/', e.g. - *

{@code
-   * Scope scope = new Scope(graph);
-   * assertEquals(scope.withSubScope("sub1").withSubScope("sub2").prefix(), "sub1/sub2");
-   * }
- */ - public String prefix() { - return nameScope.opPrefix(); - } private Scope(Graph graph, NameScope nameScope) { this.graph = graph; diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index 6d71ddfff0..5432ff244e 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -88,7 +88,11 @@ public class Gradients implements Op, Iterable> { } } } - Output[] dy = scope.graph().addGradients(scope.prefix(), Operands.asOutputs(y), Operands.asOutputs(x), dx); + Output[] dy = scope.graph().addGradients( + scope.makeOpName("Gradients"), + Operands.asOutputs(y), + Operands.asOutputs(x), + dx); return new Gradients(Arrays.asList(dy)); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c02336aebe..56c8f22daa 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -239,8 +239,20 @@ public class GraphTest { Output[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); assertTrue(grad0[0].op().name().startsWith("gradients/")); - Output[] grad1 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); - assertTrue(grad1[0].op().name().startsWith("more_gradients/")); + Output[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad1[0].op().name().startsWith("gradients_1/")); + + Output[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad2[0].op().name().startsWith("more_gradients/")); + + Output[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad3[0].op().name().startsWith("even_more_gradients/")); + + try { + g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + } catch (IllegalArgumentException e) { + // expected exception + } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java index 2fb2c1df48..0e9c7df697 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java @@ -17,7 +17,6 @@ package org.tensorflow.op; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import java.util.HashMap; @@ -183,20 +182,6 @@ public class ScopeTest { assertEquals(21704, result.intValue()); } } - - @Test - public void prefix() { - try (Graph g = new Graph()) { - Scope s = new Scope(g); - assertNull(s.prefix()); - - Scope sub1 = s.withSubScope("sub1"); - assertEquals("sub1", sub1.prefix()); - - Scope sub2 = sub1.withSubScope("sub2"); - assertEquals("sub1/sub2", sub2.prefix()); - } - } // "handwritten" sample operator classes private static final class Const { diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java index 2ffc69c209..b75f79a421 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -108,17 +108,18 @@ public class GradientsTest { } @Test - public void createGradientsWithScopeName() { + public void validateGradientsNames() { try (Graph g = new Graph()) { - Scope scope = new Scope(g); + Scope scope = new Scope(g).withSubScope("sub"); Output x = TestUtil.placeholder(g, "x1", Float.class); Output y = TestUtil.square(g, "y", x); - Scope gradScope = scope.withSubScope("grads").withSubScope("test"); - Gradients grads = Gradients.create(gradScope, y, Arrays.asList(x)); + Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x)); + assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/")); - assertTrue(grads.dy(0).op().name().startsWith("grads/test/")); + Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x)); + assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/")); } } } -- cgit v1.2.3