diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-06-27 23:07:57 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-06-27 23:07:57 -0400 |
commit | b7baff70bbdc2c785bda47c9eb06584ae46fd3b3 (patch) | |
tree | 747666add1b3ea1cb0bbc8a8acb220b7638e51ea /tensorflow/java | |
parent | 52e32a7b0ea35b52ec3a9ea5d522a08719f26068 (diff) |
Improve unit tests after TF_AddGradients fix
Diffstat (limited to 'tensorflow/java')
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 52 |
1 files changed, 36 insertions, 16 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index 3ffc249185..c2e52c22c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -139,13 +139,19 @@ public class GraphTest { Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class); Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class); Output<Float> y0 = TestUtil.square(g, "y0", x1); - Output<Float> y1 = TestUtil.addN(g, y0, x2); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + Output<Float> y2 = TestUtil.addN(g, y0, x2); - Output<?>[] grads = g.addGradients(y1, toArray(x1, x2)); - assertNotNull(grads); - assertEquals(2, grads.length); - assertEquals(DataType.FLOAT, grads[0].dataType()); - assertEquals(DataType.FLOAT, grads[1].dataType()); + Output<?>[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); try (Tensor<Float> c1 = Tensors.create(3.0f); Tensor<Float> c2 = Tensors.create(2.0f); @@ -153,12 +159,15 @@ public class GraphTest { s.runner() .feed(x1, c1) .feed(x2, c2) - .fetch(grads[0]) - .fetch(grads[1]) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) .run())) { - assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f); - assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); } } } @@ -172,12 +181,15 @@ public class GraphTest { Output<Float> y0 = TestUtil.square(g, "y0", x); Output<Float> y1 = TestUtil.square(g, "y1", y0); - Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null); + Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); try (Tensor<Float> c = Tensors.create(3.0f); Tensor<?> output = s.runner() .feed(x, c) - .fetch(grads[0]) + .fetch(grad[0]) .run() .get(0)) { @@ -192,15 +204,23 @@ public class GraphTest { Session s = new Session(g)) { Output<Float> x = TestUtil.placeholder(g, "x", Float.class); - Output<Float> y = TestUtil.square(g, "y", x); - Output<Float> dx = TestUtil.constant(g, "dx", 18.0f); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); - Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx)); + Output<?>[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); try (Tensor<Float> c = Tensors.create(3.0f); Tensor<?> output = s.runner() .feed(x, c) - .fetch(grads[0]) + .fetch(grad1[0]) .run() .get(0)) { |