diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Graph.java | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 7d19696749..f2bd3e99a5 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -153,12 +153,13 @@ public final class Graph implements AutoCloseable { * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all * shapes in {@code y}. * + * @param scopeName name of the subscope into which gradients operations are added. If null, defaults to "gradients". * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return the partial derivatives {@code dy} with the size of {@code x} */ - public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) { + public Output<?>[] addGradients(String scopeName, Output<?>[] y, Output<?>[] x, Output<?>[] dx) { Output<?>[] dy = new Output<?>[x.length]; final long[] yHandles = new long[y.length]; final int[] yIndices = new int[y.length]; @@ -185,12 +186,12 @@ public final class Graph implements AutoCloseable { dxIndices[i] = dx[i].index(); } } - // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles - // of the gradient operations while the second holds the index of their output - // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native + // handles of the gradient operations while the second holds the index of their output e.g. given + // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] long[] dyHandlesAndIndices = - addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); int ndy = dyHandlesAndIndices.length >> 1; if (ndy != dy.length) { throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length @@ -209,14 +210,14 @@ public final class Graph implements AutoCloseable { * i.e., {@code dy/dx_1, dy/dx_2...} * <p> * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is - * a single output and {@code dx} is null. + * a single output, {@code dx} is null and {@code scopeName} is null. * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @return the partial derivatives {@code dy} with the size of {@code x} */ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) { - return addGradients(new Output<?>[]{y}, x, null); + return addGradients(null, new Output<?>[]{y}, x, null); } private final Object nativeHandleLock = new Object(); @@ -330,7 +331,7 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); - private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, + private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); static { |