diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java | 48 |
1 files changed, 28 insertions, 20 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index f4671c8af9..eea9dc1c47 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.Arrays; import java.util.Iterator; import java.util.List; - import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; @@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> { * Optional attributes for {@link Gradients} */ public static class Options { - + /** * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return this option builder */ - public Options dx(Iterable<Operand<?>> dx) { + public Options dx(Iterable<? extends Operand<?>> dx) { this.dx = dx; return this; } - - private Iterable<Operand<?>> dx; - + + private Iterable<? extends Operand<?>> dx; + private Options() { } } /** * Adds gradients computation ops to the graph according to scope. - * + * * @param scope current graph scope * @param y outputs of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param options carries optional attributes values * @return a new instance of {@code Gradients} */ - public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create( + Scope scope, + Iterable<? extends Operand<?>> y, + Iterable<? extends Operand<?>> x, + Options... options) { Output<?>[] dx = null; if (options != null) { for (Options opts : options) { @@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> { } } } - Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); - return new Gradients(Arrays.asList(gradOutputs)); + Output<?>[] dy = + scope + .graph() + .addGradients( + scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(dy)); } /** * Adds gradients computation ops to the graph according to scope. - * - * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is - * a single output. - * + * + * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where + * {@code y} is a single output. + * * @param scope current graph scope * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed @@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @return a new instance of {@code Gradients} */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create( + Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) { return create(scope, (Iterable) Arrays.asList(y), x, options); } @@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return builder to add more options to this operation */ - public Options dx(Iterable<Operand<?>> dx) { + public static Options dx(Iterable<? extends Operand<?>> dx) { return new Options().dx(dx); } @@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> { public List<Output<?>> dy() { return dy; } - + /** * Returns a symbolic handle to one of the gradient operation output - * <p> - * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * + * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code - * gradients.<Integer>dy(0)} + * gradients.<Float>dy(0)} * * @param <T> The expected element type of the tensors produced by this output. * @param index The index of the output among the gradients added by this operation |