aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/Graph.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/Graph.java')
-rw-r--r-- tensorflow/java/src/main/java/org/tensorflow/Graph.java 64
1 files changed, 44 insertions, 20 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
- * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
- * <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
- * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
- * <p>
- * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
- * shapes in {@code y}.
- *
+ * Adds operations to compute the partial derivatives of the sum of {@code y}s w.r.t. {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ *
+ * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have the same
+ * size as {@code y}.
+ *
+ * <p>If {@code dx} is null, the implementation will use dx of {@link
+ * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+ *
+ * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+ * gradients. It must be unique within the provided graph or the operation will fail.
+ *
+ * <p>If {@code prefix} is null, then one will be chosen automatically.
+ *
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+ // holds the native handles of the gradient operations while the second holds the index of
+ // their output, e.g., given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(
+ ref.nativeHandle(),
+ prefix,
+ yHandles,
+ yIndices,
+ xHandles,
+ xIndices,
+ dxHandles,
+ dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code dy/dx_1, dy/dx_2...}
- * <p>
+ * <p>
* This is a simplified version of {@link #addGradients(String, Output[], Output[], Output[])} where {@code y} is
- * a single output and {@code dx} is null.
- *
+ * a single output, {@code dx} is null and {@code prefix} is null.
+ *
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[] {y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
- long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+ private static native long[] addGradients(
+ long handle,
+ String prefix,
+ long[] inputHandles,
+ int[] inputIndices,
+ long[] outputHandles,
+ int[] outputIndices,
+ long[] gradInputHandles,
+ int[] gradInputIndices);
static {
TensorFlow.init();