diff options
author | 2018-07-06 23:36:13 -0400 | |
---|---|---|
committer | 2018-07-25 21:10:29 -0400 | |
commit | ab063cd57d7eda73bcbaf11d43f8b2e6708979a3 (patch) | |
tree | ba1a613840f411f9e7e8721de00161dfba7da3aa /tensorflow/java | |
parent | 2b303fddafec6b96a6868aaa76f55cc392b96586 (diff) |
Add unit tests for Gradients
Diffstat (limited to 'tensorflow/java')
7 files changed, 148 insertions, 12 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 73e210fae0..7ceba3903d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -292,6 +292,19 @@ tf_java_test( ], ) +tf_java_test( + name = "GradientsTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.GradientsTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "processor_test_resources", srcs = glob([ diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java index 92e05d2d6d..95a2a2f9f5 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java @@ -57,7 +57,7 @@ final class NameScope { return fullyQualify(makeUnique(actualName)); } - String prefix() { + String opPrefix() { return opPrefix; } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index d1ab44c3b2..51a6ce8318 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -165,7 +165,7 @@ public final class Scope { * }</pre> */ public String prefix() { - return nameScope.prefix(); + return nameScope.opPrefix(); } private Scope(Graph graph, NameScope nameScope) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index d88dc3ba46..6d71ddfff0 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -59,12 +59,12 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return this option builder */ - public Options dx(Iterable<Operand<?>> dx) { + public Options dx(Iterable<? extends Operand<?>> dx) { this.dx = dx; return this; } - private Iterable<Operand<?>> dx; + private Iterable<? extends Operand<?>> dx; private Options() { } @@ -79,7 +79,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @param options carries optional attributes values * @return a new instance of {@code Gradients} */ - public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Options... options) { Output<?>[] dx = null; if (options != null) { for (Options opts : options) { @@ -105,7 +105,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @return a new instance of {@code Gradients} */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) { return create(scope, (Iterable) Arrays.asList(y), x, options); } @@ -113,7 +113,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return builder to add more options to this operation */ - public Options dx(Iterable<Operand<?>> dx) { + public static Options dx(Iterable<? extends Operand<?>> dx) { return new Options().dx(dx); } @@ -135,7 +135,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * <p> * Warning: Does not check that the type of the tensor matches T. It is recommended to call * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code - * gradients.<Integer>dy(0)} + * gradients.<Float>dy(0)} * * @param <T> The expected element type of the tensors produced by this output. * @param index The index of the output among the gradients added by this operation diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 4e84886416..f984c508ee 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -24,7 +24,7 @@ public class TestUtil { public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> implements AutoCloseable { - AutoCloseableList(Collection<? extends E> c) { + public AutoCloseableList(Collection<? extends E> c) { super(c); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java index 2057007499..2fb2c1df48 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java @@ -17,7 +17,7 @@ package org.tensorflow.op; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertNull; import static org.junit.Assert.fail; import java.util.HashMap; @@ -188,8 +188,7 @@ public class ScopeTest { public void prefix() { try (Graph g = new Graph()) { Scope s = new Scope(g); - assertNotNull(s.prefix()); - assertTrue(s.prefix().isEmpty()); + assertNull(s.prefix()); Scope sub1 = s.withSubScope("sub1"); assertEquals("sub1", sub1.prefix()); diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java new file mode 100644 index 0000000000..2ffc69c209 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -0,0 +1,124 @@ +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.Tensors; +import org.tensorflow.TestUtil; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class GradientsTest { + + @Test + public void createGradients() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(2, grads.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads.dy(0)) + .fetch(grads.dy(1)) + .run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithSum() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads.dy(0)) + .run())) { + + assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithInitialValues() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0)); + Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + + assertNotNull(grads1); + assertNotNull(grads1.dy()); + assertEquals(1, grads1.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + sess.runner() + .feed(x, c) + .fetch(grads1.dy(0)) + .run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithScopeName() { + try (Graph g = new Graph()) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y = TestUtil.square(g, "y", x); + + Scope gradScope = scope.withSubScope("grads").withSubScope("test"); + Gradients grads = Gradients.create(gradScope, y, Arrays.asList(x)); + + assertTrue(grads.dy(0).op().name().startsWith("grads/test/")); + } + } +} |