Diffstat (limited to 'tensorflow/java')
9 files changed, 266 insertions, 68 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 73e210fae0..7ceba3903d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -292,6 +292,19 @@ tf_java_test(
     ],
 )
 
+tf_java_test(
+    name = "GradientsTest",
+    size = "small",
+    srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+    javacopts = JAVACOPTS,
+    test_class = "org.tensorflow.op.core.GradientsTest",
+    deps = [
+        ":tensorflow",
+        ":testutil",
+        "@junit",
+    ],
+)
+
 filegroup(
     name = "processor_test_resources",
     srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
   }
 
   /**
-   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
-   * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
-   * <p>
-   * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
-   * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
-   * <p>
-   * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
-   * shapes in {@code y}.
-   *
+   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+   * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+   *
+   * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+   * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+   * {@code y}.
+   *
+   * <p>If {@code dx} is null, the implementation will use dx of {@link
+   * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+   *
+   * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+   * gradients. It must be unique within the provided graph or the operation will fail.
+   *
+   * <p>If {@code prefix} is null, then one will be chosen automatically.
+   *
+   * @param prefix unique string prefix applied before the names of nodes added to the graph to
+   *     compute gradients. If null, a default one will be chosen.
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
    * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
-  public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+  public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
     Output<?>[] dy = new Output<?>[x.length];
     final long[] yHandles = new long[y.length];
     final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
         dxIndices[i] = dx[i].index();
       }
     }
-    // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
-    // of the gradient operations while the second holds the index of their output
-    // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+    // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+    // holds the native handles of the gradient operations while the second holds the index of
+    // their output e.g. given
+    // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
     // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
     long[] dyHandlesAndIndices =
-        addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+        addGradients(
+            ref.nativeHandle(),
+            prefix,
+            yHandles,
+            yIndices,
+            xHandles,
+            xIndices,
+            dxHandles,
+            dxIndices);
     int ndy = dyHandlesAndIndices.length >> 1;
     if (ndy != dy.length) {
       throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
   /**
    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
    * i.e., {@code dy/dx_1, dy/dx_2...}
-   * <p>
+   * <p>
    * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
-   * a single output and {@code dx} is null.
-   *
+   * a single output, {@code dx} is null and {@code prefix} is null.
+   *
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
   public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
-    return addGradients(new Output<?>[]{y}, x, null);
+    return addGradients(null, new Output<?>[] {y}, x, null);
   }
 
   private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
 
   private static native byte[] toGraphDef(long handle);
 
-  private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
-      long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+  private static native long[] addGradients(
+      long handle,
+      String prefix,
+      long[] inputHandles,
+      int[] inputIndices,
+      long[] outputHandles,
+      int[] outputIndices,
+      long[] gradInputHandles,
+      int[] gradInputIndices);
 
   static {
     TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,17 +135,8 @@ public final class Scope {
    * }</pre>
    *
    * <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
-   * set of related operations to the graph by calling other operator building code) you should also
-   * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a
-   * meaningful name.
-   *
-   * <pre>{@code
-   * public static Stddev create(Scope scope, ...) {
-   *   // group sub-operations under a common name
-   *   Scope group = scope.withSubScope("stddev");
-   *   ... Sqrt.create(group, Mean.create(group, ...))
-   * }
-   * }</pre>
+   * set of related operations to the graph by calling other operator building code), the provided
+   * name will act as a subscope to all underlying operators.
    *
    * @param defaultName name for the underlying operator.
    * @return unique name for the operator.
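For orientation, here is a minimal sketch (not part of the commit) of the reworked Graph.addGradients entry point. The class name, the "my_grads" prefix, and the Placeholder/Square graph are illustrative; the default-prefix behavior ("gradients/", then "gradients_1/", and so on) and the IllegalArgumentException on a reused prefix follow the validateGradientsNames test further down.

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;

public final class AddGradientsExample {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      // Build y = x^2 with the raw graph API.
      Output<Float> x =
          g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().output(0);
      Output<Float> y = g.opBuilder("Square", "y").addInput(x).build().output(0);

      // Explicit prefix: gradient nodes are created under "my_grads/". Reusing
      // a prefix already present in the graph throws IllegalArgumentException.
      Output<?>[] dy =
          g.addGradients("my_grads", new Output<?>[] {y}, new Output<?>[] {x}, null);
      System.out.println(dy[0].op().name()); // my_grads/...

      // Null prefix: a unique default is chosen ("gradients/", "gradients_1/", ...).
      Output<?>[] dyDefault =
          g.addGradients(null, new Output<?>[] {y}, new Output<?>[] {x}, null);
      System.out.println(dyDefault[0].op().name()); // gradients/...
    }
  }
}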
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
-
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
@@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> {
    * Optional attributes for {@link Gradients}
    */
  public static class Options {
-
+
    /**
     * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
     * @return this option builder
     */
-    public Options dx(Iterable<Operand<?>> dx) {
+    public Options dx(Iterable<? extends Operand<?>> dx) {
      this.dx = dx;
      return this;
    }
-
-    private Iterable<Operand<?>> dx;
-
+
+    private Iterable<? extends Operand<?>> dx;
+
    private Options() {
    }
  }
 
  /**
   * Adds gradients computation ops to the graph according to scope.
-   *
+   *
   * @param scope current graph scope
   * @param y outputs of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
   * @param options carries optional attributes values
   * @return a new instance of {@code Gradients}
   */
-  public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+  public static Gradients create(
+      Scope scope,
+      Iterable<? extends Operand<?>> y,
+      Iterable<? extends Operand<?>> x,
+      Options... options) {
    Output<?>[] dx = null;
    if (options != null) {
      for (Options opts : options) {
@@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
        }
      }
    }
-    Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
-    return new Gradients(Arrays.asList(gradOutputs));
+    Output<?>[] dy =
+        scope
+            .graph()
+            .addGradients(
+                scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+    return new Gradients(Arrays.asList(dy));
  }
 
  /**
   * Adds gradients computation ops to the graph according to scope.
-   *
-   * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
-   * a single output.
-   *
+   *
+   * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+   * {@code y} is a single output.
+   *
   * @param scope current graph scope
   * @param y output of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
@@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
   * @return a new instance of {@code Gradients}
   */
  @SuppressWarnings({"unchecked", "rawtypes"})
-  public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+  public static Gradients create(
+      Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
    return create(scope, (Iterable) Arrays.asList(y), x, options);
  }
 
@@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
   * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
   * @return builder to add more options to this operation
   */
-  public Options dx(Iterable<Operand<?>> dx) {
+  public static Options dx(Iterable<? extends Operand<?>> dx) {
    return new Options().dx(dx);
  }
 
@@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> {
  public List<Output<?>> dy() {
    return dy;
  }
-
+
  /**
   * Returns a symbolic handle to one of the gradient operation output
-   * <p>
-   * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+   *
+   * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
   * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
-   * gradients.<Integer>dy(0)}
+   * gradients.<Float>dy(0)}
   *
   * @param <T> The expected element type of the tensors produced by this output.
   * @param index The index of the output among the gradients added by this operation
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
   return ret;
 }
 
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
-                                       jlongArray y_handles, jintArray y_indices,
-                                       jlongArray x_handles, jintArray x_indices,
-                                       jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+    JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+    jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+    jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
   TF_Graph* g = requireHandle(env, handle);
   if (g == nullptr) return nullptr;
 
@@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
   }
   if (env->ExceptionCheck()) return nullptr;
 
+  const char* cprefix = nullptr;
+  if (prefix != nullptr) {
+    cprefix = env->GetStringUTFChars(prefix, nullptr);
+  }
   TF_Status* status = TF_NewStatus();
-  TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+  TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+                            status, dy.get());
+  if (prefix != nullptr) {
+    env->ReleaseStringUTFChars(prefix, cprefix);
+  }
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
     return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
 /*
  * Class:     org_tensorflow_Graph
  * Method:    name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
  */
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
-    jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
-    jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+    JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+    jintArray, jlongArray, jintArray);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
 
 import java.util.HashSet;
 import java.util.Iterator;
-
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -180,8 +179,8 @@ public class GraphTest {
       Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
       Output<Float> y0 = TestUtil.square(g, "y0", x);
       Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
-      Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+      Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
       assertNotNull(grad);
       assertEquals(1, grad.length);
       assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +211,7 @@ public class GraphTest {
       assertEquals(1, grad0.length);
       assertEquals(DataType.FLOAT, grad0[0].dataType());
 
-      Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
       assertNotNull(grad1);
       assertEquals(1, grad1.length);
       assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -228,6 +227,33 @@ public class GraphTest {
       }
     }
   }
+
+  @Test
+  public void validateGradientsNames() {
+    try (Graph g = new Graph()) {
+
+      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+      Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+      assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
+      assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
+
+      Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+      assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
+
+      Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+      assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
+
+      try {
+        g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+      } catch (IllegalArgumentException e) {
+        // expected exception
+      }
+    }
+  }
 
   private static Output<?>[] toArray(Output<?>... outputs) {
     return outputs;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
   public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
       implements AutoCloseable {
 
-    AutoCloseableList(Collection<? extends E> c) {
+    public AutoCloseableList(Collection<? extends E> c) {
       super(c);
     }
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..3f49790b29
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+  @Test
+  public void createGradients() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+      assertNotNull(grads);
+      assertNotNull(grads.dy());
+      assertEquals(2, grads.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(
+                  sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
+        assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+        assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void createGradientsWithSum() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+      assertNotNull(grads);
+      assertNotNull(grads.dy());
+      assertEquals(1, grads.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
+        assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void createGradientsWithInitialValues() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+      Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+      assertNotNull(grads1);
+      assertNotNull(grads1.dy());
+      assertEquals(1, grads1.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(
+                  sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
+        assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void validateGradientsNames() {
+    try (Graph g = new Graph()) {
+      Scope scope = new Scope(g).withSubScope("sub");
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y = TestUtil.square(g, "y", x);
+
+      Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+      assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
+
+      Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+      assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
+    }
+  }
+}
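To round this off, a sketch (not from the commit) of the reworked Gradients op wrapper in use, reusing the illustrative Placeholder/Square graph from the earlier sketch. Per the Scope.java and Gradients.java changes above, gradient nodes land under a "Gradients/" subscope generated by scope.makeOpName("Gradients"), or under the name given via scope.withName(...).

import java.util.Arrays;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Gradients;

public final class GradientsOpExample {
  public static void main(String[] args) {
    try (Graph g = new Graph();
        Session sess = new Session(g)) {
      Scope scope = new Scope(g);

      // y = x^2, built with the raw graph API for brevity.
      Output<Float> x =
          g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().output(0);
      Output<Float> y = g.opBuilder("Square", "y").addInput(x).build().output(0);

      // Gradient nodes are added under the "Gradients/" subscope of the
      // given scope; scope.withName("MyGrads") would rename that subscope.
      Gradients grads = Gradients.create(scope, y, Arrays.asList(x));

      try (Tensor<Float> c = Tensors.create(3.0f);
          Tensor<?> dydx = sess.runner().feed(x, c).fetch(grads.<Float>dy(0)).run().get(0)) {
        // d(x^2)/dx evaluated at x = 3 is 2 * 3 = 6.
        System.out.println(dydx.floatValue()); // 6.0
      }
    }
  }
}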