From b7baff70bbdc2c785bda47c9eb06584ae46fd3b3 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Wed, 27 Jun 2018 23:07:57 -0400 Subject: Improve unit tests after TF_AddGradients fix --- .../src/test/java/org/tensorflow/GraphTest.java | 52 +++++++++++++++------- 1 file changed, 36 insertions(+), 16 deletions(-) (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index 3ffc249185..c2e52c22c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -139,13 +139,19 @@ public class GraphTest { Output x1 = TestUtil.placeholder(g, "x1", Float.class); Output x2 = TestUtil.placeholder(g, "x2", Float.class); Output y0 = TestUtil.square(g, "y0", x1); - Output y1 = TestUtil.addN(g, y0, x2); + Output y1 = TestUtil.square(g, "y1", y0); + Output y2 = TestUtil.addN(g, y0, x2); - Output[] grads = g.addGradients(y1, toArray(x1, x2)); - assertNotNull(grads); - assertEquals(2, grads.length); - assertEquals(DataType.FLOAT, grads[0].dataType()); - assertEquals(DataType.FLOAT, grads[1].dataType()); + Output[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); try (Tensor c1 = Tensors.create(3.0f); Tensor c2 = Tensors.create(2.0f); @@ -153,12 +159,15 @@ public class GraphTest { s.runner() .feed(x1, c1) .feed(x2, c2) - .fetch(grads[0]) - .fetch(grads[1]) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) .run())) { - assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f); - assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); } } } @@ -172,12 +181,15 @@ public class GraphTest { Output y0 = TestUtil.square(g, "y0", x); Output y1 = TestUtil.square(g, "y1", y0); - Output[] grads = g.addGradients(toArray(y0, y1), toArray(x), null); + Output[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); try (Tensor c = Tensors.create(3.0f); Tensor output = s.runner() .feed(x, c) - .fetch(grads[0]) + .fetch(grad[0]) .run() .get(0)) { @@ -192,15 +204,23 @@ public class GraphTest { Session s = new Session(g)) { Output x = TestUtil.placeholder(g, "x", Float.class); - Output y = TestUtil.square(g, "y", x); - Output dx = TestUtil.constant(g, "dx", 18.0f); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); - Output[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx)); + Output[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); try (Tensor c = Tensors.create(3.0f); Tensor output = s.runner() .feed(x, c) - .fetch(grads[0]) + .fetch(grad1[0]) .run() .get(0)) { -- cgit v1.2.3