aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 10:09:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 10:09:30 -0700
commit3be04971716fcaf0c11ad9262e60efa428553e14 (patch)
tree91dbe94778c1bc3ecb15188cbce072218a3dd284 /tensorflow/java
parent3fa30255c57b807c3f4b92c7f34c77e8128e911f (diff)
parentc01cfe7ced91dabc19b2392696cf0598a5df70f9 (diff)
Merge pull request #20609 from karllessard:java-add-gradients-with-name
PiperOrigin-RevId: 206779913
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java64
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java48
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc21
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h8
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java34
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java131
9 files changed, 266 insertions, 68 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 73e210fae0..7ceba3903d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -292,6 +292,19 @@ tf_java_test(
],
)
+tf_java_test(
+ name = "GradientsTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.GradientsTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "processor_test_resources",
srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
- * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
- * <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
- * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
- * <p>
- * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
- * shapes in {@code y}.
- *
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ *
+ * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+ * {@code y}.
+ *
+ * <p>If {@code dx} is null, the implementation will use dx of {@link
+ * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+ *
+ * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+ * gradients. It must be unique within the provided graph or the operation will fail.
+ *
+ * <p>If {@code prefix} is null, then one will be chosen automatically.
+ *
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+ // holds the native handles of the gradient operations while the second holds the index of
+ // their output e.g. given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(
+ ref.nativeHandle(),
+ prefix,
+ yHandles,
+ yIndices,
+ xHandles,
+ xIndices,
+ dxHandles,
+ dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code dy/dx_1, dy/dx_2...}
- * <p>
+ * <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
- * a single output and {@code dx} is null.
- *
+ * a single output, {@code dx} is null and {@code prefix} is null.
+ *
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[] {y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
- long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+ private static native long[] addGradients(
+ long handle,
+ String prefix,
+ long[] inputHandles,
+ int[] inputIndices,
+ long[] outputHandles,
+ int[] outputIndices,
+ long[] gradInputHandles,
+ int[] gradInputIndices);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,17 +135,8 @@ public final class Scope {
* }</pre>
*
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
- * set of related operations to the graph by calling other operator building code) you should also
- * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a
- * meaningful name.
- *
- * <pre>{@code
- * public static Stddev create(Scope scope, ...) {
- * // group sub-operations under a common name
- * Scope group = scope.withSubScope("stddev");
- * ... Sqrt.create(group, Mean.create(group, ...))
- * }
- * }</pre>
+ * set of related operations to the graph by calling other operator building code), the provided
+ * name will act as a subscope to all underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
-
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
@@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* Optional attributes for {@link Gradients}
*/
public static class Options {
-
+
/**
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return this option builder
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public Options dx(Iterable<? extends Operand<?>> dx) {
this.dx = dx;
return this;
}
-
- private Iterable<Operand<?>> dx;
-
+
+ private Iterable<? extends Operand<?>> dx;
+
private Options() {
}
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
+ *
* @param scope current graph scope
* @param y outputs of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param options carries optional attributes values
* @return a new instance of {@code Gradients}
*/
- public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope,
+ Iterable<? extends Operand<?>> y,
+ Iterable<? extends Operand<?>> x,
+ Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new Gradients(Arrays.asList(gradOutputs));
+ Output<?>[] dy =
+ scope
+ .graph()
+ .addGradients(
+ scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(dy));
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
- * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
- * a single output.
- *
+ *
+ * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+ * {@code y} is a single output.
+ *
* @param scope current graph scope
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
@@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @return a new instance of {@code Gradients}
*/
@SuppressWarnings({"unchecked", "rawtypes"})
- public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
return create(scope, (Iterable) Arrays.asList(y), x, options);
}
@@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return builder to add more options to this operation
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public static Options dx(Iterable<? extends Operand<?>> dx) {
return new Options().dx(dx);
}
@@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> {
public List<Output<?>> dy() {
return dy;
}
-
+
/**
* Returns a symbolic handle to one of the gradient operation output
- * <p>
- * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
* this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
- * gradients.<Integer>dy(0)}
+ * gradients.<Float>dy(0)}
*
* @param <T> The expected element type of the tensors produced by this output.
* @param index The index of the output among the gradients added by this operation
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
return ret;
}
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jlongArray y_handles, jintArray y_indices,
- jlongArray x_handles, jintArray x_indices,
- jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+ jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+ jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return nullptr;
@@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
}
if (env->ExceptionCheck()) return nullptr;
+ const char* cprefix = nullptr;
+ if (prefix != nullptr) {
+ cprefix = env->GetStringUTFChars(prefix, nullptr);
+ }
TF_Status* status = TF_NewStatus();
- TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+ TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+ status, dy.get());
+ if (prefix != nullptr) {
+ env->ReleaseStringUTFChars(prefix, cprefix);
+ }
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
* Class: org_tensorflow_Graph
* Method: name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
- jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+ jintArray, jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -180,8 +179,8 @@ public class GraphTest {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
- Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+ Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +211,7 @@ public class GraphTest {
assertEquals(1, grad0.length);
assertEquals(DataType.FLOAT, grad0[0].dataType());
- Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
assertNotNull(grad1);
assertEquals(1, grad1.length);
assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -228,6 +227,33 @@ public class GraphTest {
}
}
}
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+ Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
+
+ Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
+
+ Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
+
+ try {
+ g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+ }
+ }
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
+ public AutoCloseableList(Collection<? extends E> c) {
super(c);
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..3f49790b29
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+ @Test
+ public void createGradients() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(2, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithSum() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(1, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
+ assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithInitialValues() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+ Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+ assertNotNull(grads1);
+ assertNotNull(grads1.dy());
+ assertEquals(1, grads1.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+ Scope scope = new Scope(g).withSubScope("sub");
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y = TestUtil.square(g, "y", x);
+
+ Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+ assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
+
+ Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+ assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
+ }
+ }
+}