aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-07-06 23:36:13 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-07-25 21:10:29 -0400
commitab063cd57d7eda73bcbaf11d43f8b2e6708979a3 (patch)
treeba1a613840f411f9e7e8721de00161dfba7da3aa /tensorflow/java
parent2b303fddafec6b96a6868aaa76f55cc392b96586 (diff)
Add unit tests for Gradients
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java2
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java2
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java12
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java5
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java124
7 files changed, 148 insertions, 12 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 73e210fae0..7ceba3903d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -292,6 +292,19 @@ tf_java_test(
],
)
+tf_java_test(
+ name = "GradientsTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.GradientsTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "processor_test_resources",
srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
index 92e05d2d6d..95a2a2f9f5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
@@ -57,7 +57,7 @@ final class NameScope {
return fullyQualify(makeUnique(actualName));
}
- String prefix() {
+ String opPrefix() {
return opPrefix;
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index d1ab44c3b2..51a6ce8318 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -165,7 +165,7 @@ public final class Scope {
* }</pre>
*/
public String prefix() {
- return nameScope.prefix();
+ return nameScope.opPrefix();
}
private Scope(Graph graph, NameScope nameScope) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index d88dc3ba46..6d71ddfff0 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -59,12 +59,12 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return this option builder
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public Options dx(Iterable<? extends Operand<?>> dx) {
this.dx = dx;
return this;
}
- private Iterable<Operand<?>> dx;
+ private Iterable<? extends Operand<?>> dx;
private Options() {
}
@@ -79,7 +79,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param options carries optional attributes values
* @return a new instance of {@code Gradients}
*/
- public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -105,7 +105,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @return a new instance of {@code Gradients}
*/
@SuppressWarnings({"unchecked", "rawtypes"})
- public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
return create(scope, (Iterable) Arrays.asList(y), x, options);
}
@@ -113,7 +113,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return builder to add more options to this operation
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public static Options dx(Iterable<? extends Operand<?>> dx) {
return new Options().dx(dx);
}
@@ -135,7 +135,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* <p>
* Warning: Does not check that the type of the tensor matches T. It is recommended to call
* this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
- * gradients.<Integer>dy(0)}
+ * gradients.<Float>dy(0)}
*
* @param <T> The expected element type of the tensors produced by this output.
* @param index The index of the output among the gradients added by this operation
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
+ public AutoCloseableList(Collection<? extends E> c) {
super(c);
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 2057007499..2fb2c1df48 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -17,7 +17,7 @@ package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import java.util.HashMap;
@@ -188,8 +188,7 @@ public class ScopeTest {
public void prefix() {
try (Graph g = new Graph()) {
Scope s = new Scope(g);
- assertNotNull(s.prefix());
- assertTrue(s.prefix().isEmpty());
+ assertNull(s.prefix());
Scope sub1 = s.withSubScope("sub1");
assertEquals("sub1", sub1.prefix());
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..2ffc69c209
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,124 @@
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+ @Test
+ public void createGradients() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(2, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ sess.runner()
+ .feed(x, c)
+ .fetch(grads.dy(0))
+ .fetch(grads.dy(1))
+ .run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithSum() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(1, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ sess.runner()
+ .feed(x, c)
+ .fetch(grads.dy(0))
+ .run())) {
+
+ assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithInitialValues() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+ Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+ assertNotNull(grads1);
+ assertNotNull(grads1.dy());
+ assertEquals(1, grads1.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ sess.runner()
+ .feed(x, c)
+ .fetch(grads1.dy(0))
+ .run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithScopeName() {
+ try (Graph g = new Graph()) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y = TestUtil.square(g, "y", x);
+
+ Scope gradScope = scope.withSubScope("grads").withSubScope("test");
+ Gradients grads = Gradients.create(gradScope, y, Arrays.asList(x));
+
+ assertTrue(grads.dy(0).op().name().startsWith("grads/test/"));
+ }
+ }
+}