diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Graph.java | 64 |
1 files changed, 44 insertions, 20 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 7d19696749..752b49af04 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable { } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, - * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} - * <p> - * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function - * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. - * <p> - * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all - * shapes in {@code y}. - * + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., + * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * + * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives + * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of + * {@code y}. + * + * <p>If {@code dx} is null, the implementation will use dx of {@link + * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. + * + * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute + * gradients. It must be unique within the provided graph or the operation will fail. + * + * <p>If {@code prefix} is null, then one will be chosen automatically. + * + * @param prefix unique string prefix applied before the names of nodes added to the graph to + * compute gradients. If null, a default one will be chosen. * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return the partial derivatives {@code dy} with the size of {@code x} */ - public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) { + public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) { Output<?>[] dy = new Output<?>[x.length]; final long[] yHandles = new long[y.length]; final int[] yIndices = new int[y.length]; @@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable { dxIndices[i] = dx[i].index(); } } - // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles - // of the gradient operations while the second holds the index of their output - // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // Gradient outputs are returned in two continuous arrays concatenated into one. The first + // holds the native handles of the gradient operations while the second holds the index of + // their output e.g. given + // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] long[] dyHandlesAndIndices = - addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + addGradients( + ref.nativeHandle(), + prefix, + yHandles, + yIndices, + xHandles, + xIndices, + dxHandles, + dxIndices); int ndy = dyHandlesAndIndices.length >> 1; if (ndy != dy.length) { throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length @@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable { /** * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, * i.e., {@code dy/dx_1, dy/dx_2...} - * <p> + * <p> * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is - * a single output and {@code dx} is null. - * + * a single output, {@code dx} is null and {@code prefix} is null. + * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @return the partial derivatives {@code dy} with the size of {@code x} */ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) { - return addGradients(new Output<?>[]{y}, x, null); + return addGradients(null, new Output<?>[] {y}, x, null); } private final Object nativeHandleLock = new Object(); @@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); - private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, - long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + private static native long[] addGradients( + long handle, + String prefix, + long[] inputHandles, + int[] inputIndices, + long[] outputHandles, + int[] outputIndices, + long[] gradInputHandles, + int[] gradInputIndices); static { TensorFlow.init(); |