path: root/tensorflow/java
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-21 22:14:15 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-06-27 21:47:59 -0400
commit9b7d92dbad4a18df0c34ff425a1e236f1dd75817 (patch)
treeec4439362ab61338d05b6169a3381dfdd90842e2 /tensorflow/java
parentfac56f9c9ab58fe7406a826683559de4cef85637 (diff)
First code review
Diffstat (limited to 'tensorflow/java')
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Graph.java  | 92
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java (renamed from tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java)  | 40
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/GraphTest.java  | 83
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/TestUtil.java  | 9
4 files changed, 157 insertions, 67 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 92ab4ef4d7..7d19696749 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -147,19 +147,19 @@ public final class Graph implements AutoCloseable {
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
* <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial
- * derivatives of some loss function {@code L} w.r.t. {@code y}).
- * {@code dx} must be null or have size of {@code y}.
+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
+ * {@code L} w.r.t. {@code y}). {@code dx} must either be null or have the same size as {@code y}.
* <p>
- * If {@code dx} is null, the implementation will use dx of {@code OnesLike} for all
+ * If {@code dx} is null, the implementation will use gradients of ones ({@link org.tensorflow.op.core.OnesLike OnesLike}) for all
* shapes in {@code y}.
*
- * @param y
- * @param x
- * @param dx
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
final long[] xHandles = new long[x.length];
@@ -167,43 +167,57 @@ public final class Graph implements AutoCloseable {
long[] dxHandles = null;
int[] dxIndices = null;
- for (int i = 0; i < y.length; ++i) {
- yHandles[i] = y[i].op().getUnsafeNativeHandle();
- yIndices[i] = y[i].index();
- }
- for (int i = 0; i < x.length; ++i) {
- xHandles[i] = x[i].op().getUnsafeNativeHandle();
- xIndices[i] = x[i].index();
- }
- if (dx != null && dx.length > 0) {
- dxHandles = new long[dx.length];
- dxIndices = new int[dx.length];
+ try (Reference ref = ref()) {
+ for (int i = 0; i < y.length; ++i) {
+ yHandles[i] = y[i].op().getUnsafeNativeHandle();
+ yIndices[i] = y[i].index();
+ }
+ for (int i = 0; i < x.length; ++i) {
+ xHandles[i] = x[i].op().getUnsafeNativeHandle();
+ xIndices[i] = x[i].index();
+ }
+ if (dx != null && dx.length > 0) {
+ dxHandles = new long[dx.length];
+ dxIndices = new int[dx.length];
- for (int i = 0; i < dx.length; ++i) {
- dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
- dxIndices[i] = dx[i].index();
+ for (int i = 0; i < dx.length; ++i) {
+ dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+ dxIndices[i] = dx[i].index();
+ }
+ }
+      // Gradient outputs are returned in two contiguous arrays concatenated into one. The first half holds the
+      // native handles of the gradient operations while the second half holds the indices of their outputs,
+      // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ...], we obtain
+      // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+ long[] dyHandlesAndIndices =
+ addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ int ndy = dyHandlesAndIndices.length >> 1;
+ if (ndy != dy.length) {
+ throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
+ + " were expected");
+ }
+ for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+ Operation op = new Operation(this, dyHandlesAndIndices[i]);
+ dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
}
- }
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
- // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
- long[] dyHandlesAndIndices;
- synchronized (nativeHandleLock) {
- dyHandlesAndIndices = addGradients(nativeHandle, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
- }
- int ndy = dyHandlesAndIndices.length >> 1;
- if (ndy != x.length) {
- throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + x.length
- + " were expected");
- }
- Output<?>[] dy = new Output<?>[ndy];
- for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
- Operation op = new Operation(this, dyHandlesAndIndices[i]);
- dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
}
return dy;
}
+
+ /**
+ * Adds operations to compute the partial derivatives of {@code y} w.r.t {@code x}s,
+ * i.e., {@code dy/dx_1, dy/dx_2...}
+ * <p>
+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[])} where {@code y} is
+ * a single output and {@code dx} is null.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
+ return addGradients(new Output<?>[]{y}, x, null);
+ }
private final Object nativeHandleLock = new Object();
private long nativeHandle;
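
For illustration, a minimal caller-side sketch of the new addGradients overloads (hypothetical example class; it assumes the TestUtil.placeholder and TestUtil.square helpers from this change are available on the test classpath):

import java.util.List;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.TestUtil;

public class AddGradientsSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph();
         Session s = new Session(g)) {
      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
      Output<Float> y = TestUtil.square(g, "y", x);  // y = x^2

      // Single-output overload; equivalent to addGradients(new Output<?>[] {y}, new Output<?>[] {x}, null).
      // The native call returns [dy0Handle, ..., dy0Index, ...], which the method unpacks into Output<?>[].
      Output<?>[] dy = g.addGradients(y, new Output<?>[] {x});

      List<Tensor<?>> out = s.runner()
          .feed(x, Tensors.create(3.0f))
          .fetch(dy[0])
          .run();
      System.out.println(out.get(0).floatValue());  // prints 6.0, since dy/dx = 2x at x = 3
    }
  }
}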
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java
index 2db34bf188..097b541501 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/training/Gradients.java
@@ -40,7 +40,7 @@ import org.tensorflow.op.annotation.Operator;
* <p>
* Example of usage:
* <pre>{@code
- * AddGradients gradients = AddGradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
*
* Constant<Float> alpha = ops.constant(1.0f, Float.class);
* ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
@@ -48,10 +48,10 @@ import org.tensorflow.op.annotation.Operator;
* }</pre>
*/
@Operator
-public class AddGradients implements Op, Iterable<Operand<?>> {
+public class Gradients implements Op, Iterable<Operand<?>> {
/**
- * Optional attributes for {@link AddGradients}
+ * Optional attributes for {@link Gradients}
*/
public static class Options {
@@ -74,12 +74,12 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
* Adds gradients computation ops to the graph according to scope.
*
* @param scope current graph scope
- * @param y
- * @param x
- * @param dx
- * @return a new instance of {@code AddGradients}
+ * @param y outputs of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
*/
- public static AddGradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -89,7 +89,24 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
}
}
Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new AddGradients(Arrays.asList(gradOutputs));
+ return new Gradients(Arrays.asList(gradOutputs));
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ * <p>
+ * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
+ * a single output.
+ *
+ * @param scope current graph scope
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ return create(scope, (Iterable) Arrays.asList(y), x, options);
}
/**
@@ -107,8 +124,7 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
}
/**
- * {@code dy} of size {@code x}, i.e. the outputs of the operations added to the graph to compute gradients for each
- * {@code x} nodes respectively.
+ * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x}
*/
public List<Output<?>> dy() {
return dy;
@@ -131,7 +147,7 @@ public class AddGradients implements Op, Iterable<Operand<?>> {
private List<Output<?>> dy;
- private AddGradients(List<Output<?>> dy) {
+ private Gradients(List<Output<?>> dy) {
this.dy = dy;
}
}
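
A hedged usage sketch of the new single-operand create overload, following the class Javadoc example (loss, w, b, scope and ops are stand-ins for operands and helpers built elsewhere; they are not defined in this diff):

// Sketch only: loss, w, b, scope and ops are assumed to exist in the surrounding graph-building code.
Gradients gradients = Gradients.create(scope, loss, Arrays.asList(w, b));

Constant<Float> alpha = ops.constant(1.0f, Float.class);
ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));

The overload simply wraps the single operand before delegating to the list-based factory, so its behaviour is identical to passing Arrays.asList(loss).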
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index aa6e5f0235..ac867f1e46 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -130,24 +131,76 @@ public class GraphTest {
// expected exception.
}
}
-
+
@Test
- public void addGradientsComputationOpsToGraph() {
- try (Graph g = new Graph()) {
- Output<Integer> a = TestUtil.constant(g, "A", new int[][] {{1},{2}});
- Output<Integer> b = TestUtil.placeholder(g, "B", Integer.class);
- Output<Integer> c = TestUtil.placeholder(g, "C", Integer.class);
- Output<Integer> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
- Output<Integer> abc = TestUtil.matmul(g, "AxBxC", ab, c, false, false);
+ public void addGradientsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x1);
+ Output<Float> y1 = TestUtil.addN(g, y0, x2);
+
+ Output<?>[] grads = g.addGradients(y1, toArray(x1, x2));
+ assertNotNull(grads);
+ assertEquals(2, grads.length);
+ assertEquals(DataType.FLOAT, grads[0].dataType());
+ assertEquals(DataType.FLOAT, grads[1].dataType());
+
+ List<Tensor<?>> outputs = s.runner()
+ .feed(x1, Tensors.create(3.0f))
+ .feed(x2, Tensors.create(2.0f))
+ .fetch(grads[0])
+ .fetch(grads[1])
+ .run();
+
+ assertEquals(6.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(1).floatValue(), 0.0f);
+ }
+ }
+
+ @Test
+ public void addGradientSumsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
- Output<?>[] grad = g.addGradients(new Output<?>[] {abc}, new Output<?>[] {b, c}, null);
+ Output<?>[] grads = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+ List<Tensor<?>> outputs = s.runner()
+ .feed(x, Tensors.create(3.0f))
+ .fetch(grads[0])
+ .run();
+
+ assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+
+ @Test
+ public void addGradientsWithInitialValuesToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y = TestUtil.square(g, "y", x);
+ Output<Float> dx = TestUtil.constant(g, "dx", 18.0f);
- assertNotNull(grad);
- assertEquals(2, grad.length);
- assertNotNull(grad[0]);
- assertEquals(DataType.INT32, grad[0].dataType());
- assertNotNull(grad[1]);
- assertEquals(DataType.INT32, grad[1].dataType());
+ Output<?>[] grads = g.addGradients(toArray(y), toArray(x), toArray(dx));
+
+ List<Tensor<?>> outputs = s.runner()
+ .feed(x, Tensors.create(3.0f))
+ .fetch(grads[0])
+ .run();
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
}
}
+
+ private static Output<?>[] toArray(Output<?>... outputs) {
+ return outputs;
+ }
}
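
For reference, the expected values in the new tests follow from the chain rule: in addGradientsToGraph, y1 = x1^2 + x2, so dy1/dx1 = 2*x1 = 6.0 and dy1/dx2 = 1.0 at x1 = 3; in addGradientSumsToGraph, y0 = x^2 and y1 = x^4, so d(y0 + y1)/dx = 2x + 4x^3 = 6 + 108 = 114.0 at x = 3; and in addGradientsWithInitialValuesToGraph, the initial gradient dx = 18 scales dy/dx = 2x, giving 18 * 6 = 108.0.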
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c973b5a3d8..7feb296aed 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -36,7 +36,7 @@ public class TestUtil {
.<T>output(0);
}
- public static Output<?> addN(Graph g, Output<?>... inputs) {
+ public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
@@ -58,6 +58,13 @@ public class TestUtil {
.setAttr("num_split", numSplit)
.build();
}
+
+ public static <T> Output<T> square(Graph g, String name, Output<T> value) {
+ return g.opBuilder("Square", name)
+ .addInput(value)
+ .build()
+ .<T>output(0);
+ }
public static void transpose_A_times_X(Graph g, int[][] a) {
Output<Integer> aa = constant(g, "A", a);
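
With the generified addN signature above, test code can bind the result without a cast; a small sketch of the intended call sites (mirroring GraphTest, with g an open Graph and x1, x2 two Float placeholders):

Output<Float> y0 = TestUtil.square(g, "y0", x1);  // T bound to Float by the argument type
Output<Float> y1 = TestUtil.addN(g, y0, x2);      // T inferred as Float from the assignment target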