aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-27 23:07:57 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-06-27 23:07:57 -0400
commitb7baff70bbdc2c785bda47c9eb06584ae46fd3b3 (patch)
tree747666add1b3ea1cb0bbc8a8acb220b7638e51ea /tensorflow/java
parent52e32a7b0ea35b52ec3a9ea5d522a08719f26068 (diff)
Improve unit tests after TF_AddGradients fix
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java52
1 files changed, 36 insertions, 16 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index 3ffc249185..c2e52c22c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -139,13 +139,19 @@ public class GraphTest {
Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x1);
- Output<Float> y1 = TestUtil.addN(g, y0, x2);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+ Output<Float> y2 = TestUtil.addN(g, y0, x2);
- Output<?>[] grads = g.addGradients(y1, toArray(x1, x2));
- assertNotNull(grads);
- assertEquals(2, grads.length);
- assertEquals(DataType.FLOAT, grads[0].dataType());
- assertEquals(DataType.FLOAT, grads[1].dataType());
+ Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
+ assertNotNull(grads0);
+ assertEquals(1, grads0.length);
+ assertEquals(DataType.FLOAT, grads0[0].dataType());
+
+ Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
+ assertNotNull(grads1);
+ assertEquals(2, grads1.length);
+ assertEquals(DataType.FLOAT, grads1[0].dataType());
+ assertEquals(DataType.FLOAT, grads1[1].dataType());
try (Tensor<Float> c1 = Tensors.create(3.0f);
Tensor<Float> c2 = Tensors.create(2.0f);
@@ -153,12 +159,15 @@ public class GraphTest {
s.runner()
.feed(x1, c1)
.feed(x2, c2)
- .fetch(grads[0])
- .fetch(grads[1])
+ .fetch(grads0[0])
+ .fetch(grads1[0])
+ .fetch(grads1[1])
.run())) {
- assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f);
- assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(3, outputs.size());
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
}
}
}
@@ -172,12 +181,15 @@ public class GraphTest {
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
- Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);
+ Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ assertNotNull(grad);
+ assertEquals(1, grad.length);
+ assertEquals(DataType.FLOAT, grad[0].dataType());
try (Tensor<Float> c = Tensors.create(3.0f);
Tensor<?> output = s.runner()
.feed(x, c)
- .fetch(grads[0])
+ .fetch(grad[0])
.run()
.get(0)) {
@@ -192,15 +204,23 @@ public class GraphTest {
Session s = new Session(g)) {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
- Output<Float> y = TestUtil.square(g, "y", x);
- Output<Float> dx = TestUtil.constant(g, "dx", 18.0f);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
- Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));
+ Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
+ assertNotNull(grad0);
+ assertEquals(1, grad0.length);
+ assertEquals(DataType.FLOAT, grad0[0].dataType());
+
+ Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ assertNotNull(grad1);
+ assertEquals(1, grad1.length);
+ assertEquals(DataType.FLOAT, grad1[0].dataType());
try (Tensor<Float> c = Tensors.create(3.0f);
Tensor<?> output = s.runner()
.feed(x, c)
- .fetch(grads[0])
+ .fetch(grad1[0])
.run()
.get(0)) {