aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Graph.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index f2bd3e99a5..353092701b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -153,13 +153,14 @@ public final class Graph implements AutoCloseable {
* If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
* shapes in {@code y}.
*
- * @param scopeName name of the subscope into which gradients operations are added. If null, defaults to "gradients".
+ * @param prefix string prefix applied to names of nodes added to the graph to compute gradients.
+ * If null, defaults to "gradients".
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(String scopeName, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -191,7 +192,7 @@ public final class Graph implements AutoCloseable {
// xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(ref.nativeHandle(), prefix, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -210,7 +211,7 @@ public final class Graph implements AutoCloseable {
* i.e., {@code dy/dx_1, dy/dx_2...}
* <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
- * a single output, {@code dx} is null and {@code scopeName} is null.
+ * a single output, {@code dx} is null and {@code prefix} is null.
*
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
@@ -331,7 +332,7 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices,
+ private static native long[] addGradients(long handle, String prefix, long[] inputHandles, int[] inputIndices,
long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
static {