author     karl@kubx.ca <karl@kubx.ca>    2018-06-21 22:14:15 -0400
committer  karl@kubx.ca <karl@kubx.ca>    2018-06-27 21:47:59 -0400
commit     9b7d92dbad4a18df0c34ff425a1e236f1dd75817 (patch)
tree       ec4439362ab61338d05b6169a3381dfdd90842e2 /tensorflow/java
parent     fac56f9c9ab58fe7406a826683559de4cef85637 (diff)
First code review
Diffstat (limited to 'tensorflow/java')
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Graph.java | 92
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java (renamed from tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java) | 40
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 83
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/TestUtil.java | 9
4 files changed, 157 insertions, 67 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 92ab4ef4d7..7d19696749 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -147,19 +147,19 @@ public final class Graph implements AutoCloseable {
    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
    * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
    * <p>
-   * {@code dx} are used as initial gradients (which represent the symbolic partial
-   * derivatives of some loss function {@code L} w.r.t. {@code y}).
-   * {@code dx} must be null or have size of {@code y}.
+   * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
+   * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
    * <p>
-   * If {@code dx} is null, the implementation will use dx of {@code OnesLike} for all
+   * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
    * shapes in {@code y}.
    *
-   * @param y
-   * @param x
-   * @param dx
+   * @param y output of the function to derive
+   * @param x inputs of the function for which partial derivatives are computed
+   * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
   public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+    Output<?>[] dy = new Output<?>[x.length];
     final long[] yHandles = new long[y.length];
     final int[] yIndices = new int[y.length];
     final long[] xHandles = new long[x.length];
@@ -167,43 +167,57 @@ public final class Graph implements AutoCloseable {
     long[] dxHandles = null;
     int[] dxIndices = null;
 
-    for (int i = 0; i < y.length; ++i) {
-      yHandles[i] = y[i].op().getUnsafeNativeHandle();
-      yIndices[i] = y[i].index();
-    }
-    for (int i = 0; i < x.length; ++i) {
-      xHandles[i] = x[i].op().getUnsafeNativeHandle();
-      xIndices[i] = x[i].index();
-    }
-    if (dx != null && dx.length > 0) {
-      dxHandles = new long[dx.length];
-      dxIndices = new int[dx.length];
+    try (Reference ref = ref()) {
+      for (int i = 0; i < y.length; ++i) {
+        yHandles[i] = y[i].op().getUnsafeNativeHandle();
+        yIndices[i] = y[i].index();
+      }
+      for (int i = 0; i < x.length; ++i) {
+        xHandles[i] = x[i].op().getUnsafeNativeHandle();
+        xIndices[i] = x[i].index();
+      }
+      if (dx != null && dx.length > 0) {
+        dxHandles = new long[dx.length];
+        dxIndices = new int[dx.length];
 
-      for (int i = 0; i < dx.length; ++i) {
-        dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
-        dxIndices[i] = dx[i].index();
+        for (int i = 0; i < dx.length; ++i) {
+          dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+          dxIndices[i] = dx[i].index();
+        }
+      }
+      // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
+      // of the gradient operations while the second holds the index of their output
+      // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+      // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+      long[] dyHandlesAndIndices =
+          addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+      int ndy = dyHandlesAndIndices.length >> 1;
+      if (ndy != dy.length) {
+        throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
+            + " were expected");
+      }
+      for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+        Operation op = new Operation(this, dyHandlesAndIndices[i]);
+        dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
       }
-    }
-    // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
-    // of the gradient operations while the second holds the index of their output
-    // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
-    // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
-    long[] dyHandlesAndIndices;
-    synchronized (nativeHandleLock) {
-      dyHandlesAndIndices = addGradients(nativeHandle, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
-    }
-    int ndy = dyHandlesAndIndices.length >> 1;
-    if (ndy != x.length) {
-      throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + x.length
-          + " were expected");
-    }
-    Output<?>[] dy = new Output<?>[ndy];
-    for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
-      Operation op = new Operation(this, dyHandlesAndIndices[i]);
-      dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
     }
     return dy;
   }
+
+  /**
+   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+   * i.e., {@code dy/dx_1, dy/dx_2...}
+   * <p>
+   * This is a simplified version of {@link #addGradients(Output[], Output[], Output[])} where {@code y} is
+   * a single output and {@code dx} is null.
+   *
+   * @param y output of the function to derive
+   * @param x inputs of the function for which partial derivatives are computed
+   * @return the partial derivatives {@code dy} with the size of {@code x}
+   */
+  public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
+    return addGradients(new Output<?>[]{y}, x, null);
+  }
 
   private final Object nativeHandleLock = new Object();
   private long nativeHandle;
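The new single-output overload above makes the common case, one output differentiated with default OnesLike seed gradients, a one-liner. A minimal sketch of the intended call pattern, assuming the Session, Tensors and DataType APIs exercised by this patch's tests (the y = x*x graph is illustrative, not part of the commit):

    // Sketch: compute dy/dx for y = x*x at x = 5.0 with the new overload.
    try (Graph g = new Graph();
        Session s = new Session(g)) {
      Output<Float> x = g.opBuilder("Placeholder", "x")
          .setAttr("dtype", DataType.FLOAT)
          .build()
          .<Float>output(0);
      Output<Float> y = g.opBuilder("Square", "y")
          .addInput(x)
          .build()
          .<Float>output(0);

      // dx is null here, so the gradients are seeded with OnesLike, as documented above.
      Output<?>[] grads = g.addGradients(y, new Output<?>[] {x});

      try (Tensor<Float> in = Tensors.create(5.0f)) {
        float dydx = s.runner()
            .feed(x, in)
            .fetch(grads[0])
            .run()
            .get(0)
            .floatValue();  // 10.0f, since d(x^2)/dx = 2x
      }
    }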
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java
index 2db34bf188..097b541501 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java
@@ -40,7 +40,7 @@ import org.tensorflow.op.annotation.Operator;
  * <p>
  * Example of usage:
  * <pre>{@code
- * AddGradients gradients = AddGradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
  *
  * Constant<Float> alpha = ops.constant(1.0f, Float.class);
  * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
@@ -48,10 +48,10 @@ import org.tensorflow.op.annotation.Operator;
  * }</pre>
  */
 @Operator
-public class AddGradients implements Op, Iterable<Operand<?>> {
+public class Gradients implements Op, Iterable<Operand<?>> {
 
   /**
-   * Optional attributes for {@link AddGradients}
+   * Optional attributes for {@link Gradients}
    */
   public static class Options {
 
@@ -74,12 +74,12 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
    * Adds gradients computation ops to the graph according to scope.
    *
    * @param scope current graph scope
-   * @param y
-   * @param x
-   * @param dx
-   * @return a new instance of {@code AddGradients}
+   * @param y outputs of the function to derive
+   * @param x inputs of the function for which partial derivatives are computed
+   * @param options carries optional attributes values
+   * @return a new instance of {@code Gradients}
    */
-  public static AddGradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+  public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
     Output<?>[] dx = null;
     if (options != null) {
       for (Options opts : options) {
@@ -89,7 +89,24 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
       }
     }
     Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
-    return new AddGradients(Arrays.asList(gradOutputs));
+    return new Gradients(Arrays.asList(gradOutputs));
+  }
+
+  /**
+   * Adds gradients computation ops to the graph according to scope.
+   *
+   * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
+   * a single output.
+   *
+   * @param scope current graph scope
+   * @param y output of the function to derive
+   * @param x inputs of the function for which partial derivatives are computed
+   * @param options carries optional attributes values
+   * @return a new instance of {@code Gradients}
+   */
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+    return create(scope, (Iterable) Arrays.asList(y), x, options);
   }
 
   /**
@@ -107,8 +124,7 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
   }
 
   /**
-   * {@code dy} of size {@code x}, i.e. the outputs of the operations added to the graph to compute gradients for each
-   * {@code x} nodes respectively.
+   * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x}
    */
   public List<Output<?>> dy() {
     return dy;
@@ -131,7 +147,7 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
 
   private List<Output<?>> dy;
 
-  private AddGradients(List<Output<?>> dy) {
+  private Gradients(List<Output<?>> dy) {
     this.dy = dy;
   }
 }
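At the op-library level the renamed Gradients wrapper is the user-facing entry point. A sketch mirroring the class javadoc's usage example, but going through the new single-output create() overload; scope, ops, loss, w and b are assumed to be provided by the surrounding graph-building code, exactly as in the javadoc:

    // One gradient-descent step on w and b, using the single-output create().
    Gradients gradients = Gradients.create(scope, loss, Arrays.asList(w, b));

    Constant<Float> alpha = ops.constant(1.0f, Float.class);
    ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
    ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));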
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index aa6e5f0235..ac867f1e46 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
 
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -130,24 +131,76 @@ public class GraphTest {
       // expected exception.
     }
   }
-  
+
   @Test
-  public void addGradientsComputationOpsToGraph() {
-    try (Graph g = new Graph()) {
-      Output<Integer> a = TestUtil.constant(g, "A", new int[][] {{1},{2}});
-      Output<Integer> b = TestUtil.placeholder(g, "B", Integer.class);
-      Output<Integer> c = TestUtil.placeholder(g, "C", Integer.class);
-      Output<Integer> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
-      Output<Integer> abc = TestUtil.matmul(g, "AxBxC", ab, c, false, false);
+  public void addGradientsToGraph() {
+    try (Graph g = new Graph();
+        Session s = new Session(g)) {
+
+      Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x1);
+      Output<Float> y1 = TestUtil.addN(g, y0, x2);
+
+      Output<?>[] grads = g.addGradients(y1, toArray(x1, x2));
+      assertNotNull(grads);
+      assertEquals(2, grads.length);
+      assertEquals(DataType.FLOAT, grads[0].dataType());
+      assertEquals(DataType.FLOAT, grads[1].dataType());
+
+      List<Tensor<?>> outputs = s.runner()
+          .feed(x1, Tensors.create(3.0f))
+          .feed(x2, Tensors.create(2.0f))
+          .fetch(grads[0])
+          .fetch(grads[1])
+          .run();
+
+      assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f);
+      assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f);
+    }
+  }
+
+  @Test
+  public void addGradientSumsToGraph() {
+    try (Graph g = new Graph();
+        Session s = new Session(g)) {
+
+      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
 
-      Output<?>[] grad = g.addGradients(new Output<?>[] {abc}, new Output<?>[] {b, c}, null);
+      Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+      List<Tensor<?>> outputs = s.runner()
+          .feed(x, Tensors.create(3.0f))
+          .fetch(grads[0])
+          .run();
+
+      assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+    }
+  }
+
+  @Test
+  public void addGradientsWithInitialValuesToGraph() {
+    try (Graph g = new Graph();
+        Session s = new Session(g)) {
+
+      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+      Output<Float> y = TestUtil.square(g, "y", x);
+      Output<Float> dx = TestUtil.constant(g, "dx", 18.0f);
 
-      assertNotNull(grad);
-      assertEquals(2, grad.length);
-      assertNotNull(grad[0]);
-      assertEquals(DataType.INT32, grad[0].dataType());
-      assertNotNull(grad[1]);
-      assertEquals(DataType.INT32, grad[1].dataType());
+      Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));
+
+      List<Tensor<?>> outputs = s.runner()
+          .feed(x, Tensors.create(3.0f))
+          .fetch(grads[0])
+          .run();
+
+      assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
     }
   }
+
+  private static Output<?>[] toArray(Output<?>... outputs) {
+    return outputs;
+  }
 }
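The expected values in these tests follow from basic calculus on the tests' own graphs; a quick arithmetic check (plain Java, variable names are illustrative):

    // addGradientsToGraph: y1 = x1^2 + x2, so dy1/dx1 = 2*x1 = 6.0 at x1 = 3,
    // and dy1/dx2 = 1.0.
    // addGradientSumsToGraph: y0 = x^2 and y1 = y0^2 = x^4; addGradients
    // differentiates the sum of the ys, so d(y0 + y1)/dx = 2x + 4x^3,
    // which at x = 3 is 6 + 108 = 114.
    // addGradientsWithInitialValuesToGraph: the seed dx = 18 scales dy/dx = 2x,
    // giving 18 * 2 * 3 = 108 at x = 3.
    float x = 3.0f;
    float gradOfSum = 2 * x + 4 * x * x * x;  // 114.0f
    float seededGrad = 18.0f * 2 * x;         // 108.0f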
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c973b5a3d8..7feb296aed 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -36,7 +36,7 @@ public class TestUtil {
         .<T>output(0);
   }
 
-  public static Output<?> addN(Graph g, Output<?>... inputs) {
+  public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
     return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
   }
 
@@ -58,6 +58,13 @@ public class TestUtil {
         .setAttr("num_split", numSplit)
         .build();
   }
+
+  public static <T> Output<T> square(Graph g, String name, Output<T> value) {
+    return g.opBuilder("Square", name)
+        .addInput(value)
+        .build()
+        .<T>output(0);
+  }
 
   public static void transpose_A_times_X(Graph g, int[][] a) {
     Output<Integer> aa = constant(g, "A", a);
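The generic addN and the new square helper now carry the output type through <T>, so test graphs stay statically typed. A small sketch, assuming the existing TestUtil.placeholder helper from this same file:

    // The typed helpers infer T from the assignment target.
    try (Graph g = new Graph()) {
      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
      Output<Float> x2 = TestUtil.square(g, "x2", x);  // builds a "Square" op
      Output<Float> sum = TestUtil.addN(g, x2, x2);    // typed thanks to <T> addN
    }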