From ab063cd57d7eda73bcbaf11d43f8b2e6708979a3 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Fri, 6 Jul 2018 23:36:13 -0400 Subject: Add unit tests for Gradients --- tensorflow/java/BUILD | 13 +++ .../src/main/java/org/tensorflow/op/NameScope.java | 2 +- .../src/main/java/org/tensorflow/op/Scope.java | 2 +- .../java/org/tensorflow/op/core/Gradients.java | 12 +- .../src/test/java/org/tensorflow/TestUtil.java | 2 +- .../src/test/java/org/tensorflow/op/ScopeTest.java | 5 +- .../java/org/tensorflow/op/core/GradientsTest.java | 124 +++++++++++++++++++++ 7 files changed, 148 insertions(+), 12 deletions(-) create mode 100644 tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java (limited to 'tensorflow/java') diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 73e210fae0..7ceba3903d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -292,6 +292,19 @@ tf_java_test( ], ) +tf_java_test( + name = "GradientsTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.GradientsTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "processor_test_resources", srcs = glob([ diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java index 92e05d2d6d..95a2a2f9f5 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java @@ -57,7 +57,7 @@ final class NameScope { return fullyQualify(makeUnique(actualName)); } - String prefix() { + String opPrefix() { return opPrefix; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index d1ab44c3b2..51a6ce8318 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -165,7 +165,7 @@ public final class Scope { * } */ public String prefix() { - return nameScope.prefix(); + return nameScope.opPrefix(); } private Scope(Graph graph, NameScope nameScope) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index d88dc3ba46..6d71ddfff0 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -59,12 +59,12 @@ public class Gradients implements Op, Iterable> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return this option builder */ - public Options dx(Iterable> dx) { + public Options dx(Iterable> dx) { this.dx = dx; return this; } - private Iterable> dx; + private Iterable> dx; private Options() { } @@ -79,7 +79,7 @@ public class Gradients implements Op, Iterable> { * @param options carries optional attributes values * @return a new instance of {@code Gradients} */ - public static Gradients create(Scope scope, Iterable> y, Iterable> x, Options... options) { + public static Gradients create(Scope scope, Iterable> y, Iterable> x, Options... options) { Output[] dx = null; if (options != null) { for (Options opts : options) { @@ -105,7 +105,7 @@ public class Gradients implements Op, Iterable> { * @return a new instance of {@code Gradients} */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Gradients create(Scope scope, Operand y, Iterable> x, Options... options) { + public static Gradients create(Scope scope, Operand y, Iterable> x, Options... options) { return create(scope, (Iterable) Arrays.asList(y), x, options); } @@ -113,7 +113,7 @@ public class Gradients implements Op, Iterable> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return builder to add more options to this operation */ - public Options dx(Iterable> dx) { + public static Options dx(Iterable> dx) { return new Options().dx(dx); } @@ -135,7 +135,7 @@ public class Gradients implements Op, Iterable> { *

* Warning: Does not check that the type of the tensor matches T. It is recommended to call * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code - * gradients.dy(0)} + * gradients.dy(0)} * * @param The expected element type of the tensors produced by this output. * @param index The index of the output among the gradients added by this operation diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 4e84886416..f984c508ee 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -24,7 +24,7 @@ public class TestUtil { public static final class AutoCloseableList extends ArrayList implements AutoCloseable { - AutoCloseableList(Collection c) { + public AutoCloseableList(Collection c) { super(c); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java index 2057007499..2fb2c1df48 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java @@ -17,7 +17,7 @@ package org.tensorflow.op; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import java.util.HashMap; @@ -188,8 +188,7 @@ public class ScopeTest { public void prefix() { try (Graph g = new Graph()) { Scope s = new Scope(g); - assertNotNull(s.prefix()); - assertTrue(s.prefix().isEmpty()); + assertNull(s.prefix()); Scope sub1 = s.withSubScope("sub1"); assertEquals("sub1", sub1.prefix()); diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java new file mode 100644 index 0000000000..2ffc69c209 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -0,0 +1,124 @@ +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.Tensors; +import org.tensorflow.TestUtil; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class GradientsTest { + + @Test + public void createGradients() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(2, grads.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads.dy(0)) + .fetch(grads.dy(1)) + .run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithSum() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads.dy(0)) + .run())) { + + assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithInitialValues() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0)); + Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + + assertNotNull(grads1); + assertNotNull(grads1.dy()); + assertEquals(1, grads1.dy().size()); + + try (Tensor c = Tensors.create(3.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads1.dy(0)) + .run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithScopeName() { + try (Graph g = new Graph()) { + Scope scope = new Scope(g); + + Output x = TestUtil.placeholder(g, "x1", Float.class); + Output y = TestUtil.square(g, "y", x); + + Scope gradScope = scope.withSubScope("grads").withSubScope("test"); + Gradients grads = Gradients.create(gradScope, y, Arrays.asList(x)); + + assertTrue(grads.dy(0).op().name().startsWith("grads/test/")); + } + } +} -- cgit v1.2.3