author     TensorFlower Gardener <gardener@tensorflow.org>   2018-07-31 10:09:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-07-31 10:09:30 -0700
commit     3be04971716fcaf0c11ad9262e60efa428553e14 (patch)
tree       91dbe94778c1bc3ecb15188cbce072218a3dd284
parent     3fa30255c57b807c3f4b92c7f34c77e8128e911f (diff)
parent     c01cfe7ced91dabc19b2392696cf0598a5df70f9 (diff)
Merge pull request #20609 from karllessard:java-add-gradients-with-name
PiperOrigin-RevId: 206779913
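This merge adds `TF_AddGradientsWithPrefix` to the C API and threads an optional `prefix` argument through the Java bindings (`Graph.addGradients`, `Gradients.create`), letting callers pick the name scope under which gradient nodes are created. A minimal sketch of the new Java surface follows; the class name and the `"my_grads"` prefix are illustrative, and the graph-building calls mirror the `TestUtil.placeholder`/`TestUtil.square` helpers used by the tests in this change:

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;

public final class GradientsPrefixSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      // Build y = x^2 with the raw OperationBuilder API.
      Output<Float> x =
          g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().<Float>output(0);
      Output<Float> y = g.opBuilder("Square", "y").addInput(x).build().<Float>output(0);

      // Null prefix keeps the old behavior: nodes are named "gradients/...",
      // then "gradients_1/..." on the next call, and so on.
      Output<?>[] dflt = g.addGradients(null, new Output<?>[] {y}, new Output<?>[] {x}, null);

      // Custom prefix: gradient nodes are named "my_grads/...".
      Output<?>[] custom = g.addGradients("my_grads", new Output<?>[] {y}, new Output<?>[] {x}, null);

      System.out.println(dflt[0].op().name());   // starts with "gradients/"
      System.out.println(custom[0].op().name()); // starts with "my_grads/"
    }
  }
}
```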
-rw-r--r-- | tensorflow/c/c_api.cc | 41
-rw-r--r-- | tensorflow/c/c_api.h | 34
-rw-r--r-- | tensorflow/c/c_api_test.cc | 84
-rw-r--r-- | tensorflow/java/BUILD | 13
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/Graph.java | 64
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/Scope.java | 13
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java | 48
-rw-r--r-- | tensorflow/java/src/main/native/graph_jni.cc | 21
-rw-r--r-- | tensorflow/java/src/main/native/graph_jni.h | 8
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/GraphTest.java | 34
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/TestUtil.java | 2
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java | 131
12 files changed, 415 insertions, 78 deletions
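The invariant the new C-API tests pin down is that a user-supplied prefix must be unique within the graph: reusing a prefix, or choosing one that matches an existing node name, fails with `TF_INVALID_ARGUMENT`, which the Java binding surfaces as `IllegalArgumentException`. A hedged sketch of that failure mode (hypothetical names, mirroring `GraphTest.validateGradientsNames` in the diff below):

```java
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;

public final class PrefixConflictSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Output<Float> x =
          g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().<Float>output(0);
      Output<Float> y = g.opBuilder("Square", "y").addInput(x).build().<Float>output(0);

      // First use of the prefix succeeds.
      g.addGradients("grads", new Output<?>[] {y}, new Output<?>[] {x}, null);

      try {
        // Reusing the same prefix is rejected by the name_map scan in
        // TF_AddGradientsWithPrefix.
        g.addGradients("grads", new Output<?>[] {y}, new Output<?>[] {x}, null);
      } catch (IllegalArgumentException e) {
        // e.g. "prefix [grads] conflicts with existing node in the graph named [...]"
        System.out.println(e.getMessage());
      }
    }
  }
}
```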
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 10bc8cdbee..19ccb6e71d 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -52,6 +52,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/mutex.h"
@@ -2389,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
 
 void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
                      TF_Output* dx, TF_Status* status, TF_Output* dy) {
+  TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
+}
+
+void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
+                               int ny, TF_Output* x, int nx, TF_Output* dx,
+                               TF_Status* status, TF_Output* dy) {
 #ifdef __ANDROID__
   status->status = tensorflow::errors::Unimplemented(
       "Adding gradients is not supported in Android. File a bug at "
@@ -2405,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
 
   const int first_new_node_id = g->graph.num_node_ids();
 
+  string prefix_cmp;
+  const char* child_scope_name;
+  if (prefix == nullptr) {
+    child_scope_name = "gradients";
+  } else {
+    prefix_cmp = string(prefix) + "/";
+    // The operation should fail if the provided name prefix has already been
+    // used in this graph
+    for (const auto& pair : g->name_map) {
+      const string& name = pair.first;
+      if (name.compare(prefix) == 0 ||
+          tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+        status->status = InvalidArgument(
+            "prefix [", prefix,
+            "] conflicts with existing node in the graph named [", name, "]");
+        return;
+      }
+    }
+    child_scope_name = prefix;
+  }
   tensorflow::Scope scope =
       NewInternalScope(&g->graph, &status->status, &g->refiner)
-          .NewSubScope("gradients");
+          .NewSubScope(child_scope_name);
 
   if (dx != nullptr) {
     std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
@@ -2422,6 +2449,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
     for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
       Node* n = g->graph.FindNodeId(i);
       if (n == nullptr) continue;
+
+      // Adding the gradients to the graph can alter the prefix to prevent
+      // name collisions only if this prefix has not been provided explicitly
+      // by the user. If it was provided, assert that it remained intact.
+      if (prefix != nullptr &&
+          !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+        status->status = tensorflow::errors::Internal(
+            "BUG: The gradients prefix have been unexpectedly altered when "
+            "adding the nodes to the graph. This is a bug. Please file an "
+            "issue at https://github.com/tensorflow/tensorflow/issues.");
+        return;
+      }
       // We have a convoluted scheme here: Using the C++ graph construction API
       // to add potentially many nodes to the graph without running the checks
       // (such as uniqueness of the names of nodes) we run with other functions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 7e97351c8a..850f6ecd63 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1131,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
 
 // Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
 // i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+//
 // `dx` are used as initial gradients (which represent the symbolic partial
 // derivatives of some loss function `L` w.r.t. `y`).
 // `dx` must be nullptr or have size `ny`.
@@ -1139,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
 // The partial derivatives are returned in `dy`. `dy` should be allocated to
 // size `nx`.
 //
+// Gradient nodes are automatically named under the "gradients/" prefix. To
+// guarantee name uniqueness, subsequent calls to the same graph will
+// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ...
+// See TF_AddGradientsWithPrefix, which provides a means to specify a custom
+// name prefix for operations added to a graph to compute the gradients.
+//
 // WARNING: This function does not yet support all the gradients that python
 // supports. See
 // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
@@ -1147,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
                                     TF_Output* x, int nx, TF_Output* dx,
                                     TF_Status* status, TF_Output* dy);
 
+// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
+// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+// This is a variant of TF_AddGradients that allows to caller to pass a custom
+// name prefix to the operations added to a graph to compute the gradients.
+//
+// `dx` are used as initial gradients (which represent the symbolic partial
+// derivatives of some loss function `L` w.r.t. `y`).
+// `dx` must be nullptr or have size `ny`.
+// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all
+// shapes in `y`.
+// The partial derivatives are returned in `dy`. `dy` should be allocated to
+// size `nx`.
+// `prefix` names the scope into which all gradients operations are being added.
+// `prefix` must be unique within the provided graph otherwise this operation
+// will fail. If `prefix` is nullptr, the default prefixing behaviour takes
+// place, see TF_AddGradients for more details.
+//
+// WARNING: This function does not yet support all the gradients that python
+// supports. See
+// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
+// for instructions on how to add C++ more gradients.
+TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix,
+                                              TF_Output* y, int ny,
+                                              TF_Output* x, int nx,
+                                              TF_Output* dx, TF_Status* status,
+                                              TF_Output* dy);
+
 // Create a TF_Function from a TF_Graph
 //
 // Params:
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index e674b1623c..aa2a537f03 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1483,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test {
     BuildSuccessGraph(inputs, outputs);
     BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
 
-    AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
-
+    AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
+                 grad_outputs);
     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
 
     // Compare that the graphs match.
@@ -1505,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test {
 
     BuildErrorGraph(inputs, outputs);
 
-    AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+    AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
+                 grad_outputs);
 
     string expected_msg =
         "No gradient defined for op: TestOpWithNoGradient. Please see "
@@ -1549,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test {
     EXPECT_EQ(*a_data, *b_data);
   }
 
-  void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
-                    TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
+  void AddGradients(bool grad_inputs_provided, const char* prefix,
+                    TF_Output* inputs, int ninputs, TF_Output* outputs,
+                    int noutputs, TF_Output* grad_outputs) {
     if (grad_inputs_provided) {
       TF_Output grad_inputs[1];
      const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
       TF_Operation* grad_inputs_op =
           FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
       grad_inputs[0] = TF_Output{grad_inputs_op, 0};
-      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
-                      s_, grad_outputs);
+      TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+                                ninputs, grad_inputs, s_, grad_outputs);
    } else {
-      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
-                      grad_outputs);
+      TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+                                ninputs, nullptr, s_, grad_outputs);
     }
   }
 
@@ -1706,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test {
     return op;
   }
 
+  void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
+                                             const char* prefix2 = nullptr) {
+    TF_Output inputs[2];
+    TF_Output outputs[1];
+    TF_Output grad_outputs[2];
+
+    BuildSuccessGraph(inputs, outputs);
+
+    AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
+    if (prefix2 != nullptr) {
+      AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
+    }
+  }
+
   TF_Status* s_;
   TF_Graph* graph_;
   TF_Graph* expected_graph_;
@@ -1725,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
   TestGradientsError(false);
 }
 
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients");
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
+  BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
+  BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
+  BuildGraphAndAddGradientsWithPrefixes("Const_0");
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
+  BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
+  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
 void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
   ASSERT_TRUE(t != nullptr);
   ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 73e210fae0..7ceba3903d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -292,6 +292,19 @@ tf_java_test(
     ],
 )
 
+tf_java_test(
+    name = "GradientsTest",
+    size = "small",
+    srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+    javacopts = JAVACOPTS,
+    test_class = "org.tensorflow.op.core.GradientsTest",
+    deps = [
+        ":tensorflow",
+        ":testutil",
+        "@junit",
+    ],
+)
+
 filegroup(
     name = "processor_test_resources",
     srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
   }
 
   /**
-   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
-   * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
-   * <p>
-   * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
-   * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
-   * <p>
-   * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
-   * shapes in {@code y}.
-   *
+   * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+   * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+   *
+   * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+   * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+   * {@code y}.
+   *
+   * <p>If {@code dx} is null, the implementation will use dx of {@link
+   * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+   *
+   * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+   * gradients. It must be unique within the provided graph or the operation will fail.
+   *
+   * <p>If {@code prefix} is null, then one will be chosen automatically.
+   *
+   * @param prefix unique string prefix applied before the names of nodes added to the graph to
+   *     compute gradients. If null, a default one will be chosen.
    * @param y output of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
    * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
-  public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+  public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
     Output<?>[] dy = new Output<?>[x.length];
     final long[] yHandles = new long[y.length];
     final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
         dxIndices[i] = dx[i].index();
       }
     }
-    // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
-    // of the gradient operations while the second holds the index of their output
-    // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+    // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+    // holds the native handles of the gradient operations while the second holds the index of
+    // their output e.g. given
+    // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
     // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
     long[] dyHandlesAndIndices =
-        addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+        addGradients(
+            ref.nativeHandle(),
+            prefix,
+            yHandles,
+            yIndices,
+            xHandles,
+            xIndices,
+            dxHandles,
+            dxIndices);
     int ndy = dyHandlesAndIndices.length >> 1;
     if (ndy != dy.length) {
       throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
  /**
    * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
    * i.e., {@code dy/dx_1, dy/dx_2...}
-   * <p>
+   * <p>
    * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
-   * a single output and {@code dx} is null.
-   *
+   * a single output, {@code dx} is null and {@code prefix} is null.
+   *
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
   public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
-    return addGradients(new Output<?>[]{y}, x, null);
+    return addGradients(null, new Output<?>[] {y}, x, null);
   }
 
   private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
 
   private static native byte[] toGraphDef(long handle);
 
-  private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
-      long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+  private static native long[] addGradients(
+      long handle,
+      String prefix,
+      long[] inputHandles,
+      int[] inputIndices,
+      long[] outputHandles,
+      int[] outputIndices,
+      long[] gradInputHandles,
+      int[] gradInputIndices);
 
   static {
     TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,17 +135,8 @@ public final class Scope {
    * }</pre>
    *
    * <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
-   * set of related operations to the graph by calling other operator building code) you should also
-   * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a
-   * meaningful name.
-   *
-   * <pre>{@code
-   * public static Stddev create(Scope scope, ...) {
-   *   // group sub-operations under a common name
-   *   Scope group = scope.withSubScope("stddev");
-   *   ...  Sqrt.create(group, Mean.create(group, ...))
-   * }
-   * }</pre>
+   * set of related operations to the graph by calling other operator building code), the provided
+   * name will act as a subscope to all underlying operators.
    *
    * @param defaultName name for the underlying operator.
    * @return unique name for the operator.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
-
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
@@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> {
    * Optional attributes for {@link Gradients}
    */
  public static class Options {
-
+
     /**
      * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
      * @return this option builder
      */
-    public Options dx(Iterable<Operand<?>> dx) {
+    public Options dx(Iterable<? extends Operand<?>> dx) {
       this.dx = dx;
       return this;
     }
-
-    private Iterable<Operand<?>> dx;
-
+
+    private Iterable<? extends Operand<?>> dx;
+
     private Options() {
     }
   }
 
   /**
    * Adds gradients computation ops to the graph according to scope.
-   *
+   *
    * @param scope current graph scope
    * @param y outputs of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
    * @param options carries optional attributes values
    * @return a new instance of {@code Gradients}
    */
-  public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+  public static Gradients create(
+      Scope scope,
+      Iterable<? extends Operand<?>> y,
+      Iterable<? extends Operand<?>> x,
+      Options... options) {
     Output<?>[] dx = null;
     if (options != null) {
       for (Options opts : options) {
@@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
         }
       }
     }
-    Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
-    return new Gradients(Arrays.asList(gradOutputs));
+    Output<?>[] dy =
+        scope
+            .graph()
+            .addGradients(
+                scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+    return new Gradients(Arrays.asList(dy));
   }
 
   /**
    * Adds gradients computation ops to the graph according to scope.
-   *
-   * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
-   * a single output.
-   *
+   *
+   * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+   * {@code y} is a single output.
+   *
    * @param scope current graph scope
    * @param y output of the function to derive
   * @param x inputs of the function for which partial derivatives are computed
@@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
    * @return a new instance of {@code Gradients}
    */
   @SuppressWarnings({"unchecked", "rawtypes"})
-  public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+  public static Gradients create(
+      Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
     return create(scope, (Iterable) Arrays.asList(y), x, options);
   }
 
@@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
    * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
    * @return builder to add more options to this operation
    */
-  public Options dx(Iterable<Operand<?>> dx) {
+  public static Options dx(Iterable<? extends Operand<?>> dx) {
     return new Options().dx(dx);
   }
 
@@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> {
   public List<Output<?>> dy() {
     return dy;
   }
-
+
   /**
    * Returns a symbolic handle to one of the gradient operation output
-   * <p>
-   * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+   *
+   * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
    * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
-   * gradients.<Integer>dy(0)}
+   * gradients.<Float>dy(0)}
    *
    * @param <T> The expected element type of the tensors produced by this output.
    * @param index The index of the output among the gradients added by this operation
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
   return ret;
 }
 
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
-                                       jlongArray y_handles, jintArray y_indices,
-                                       jlongArray x_handles, jintArray x_indices,
-                                       jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+    JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+    jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+    jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
   TF_Graph* g = requireHandle(env, handle);
   if (g == nullptr) return nullptr;
 
@@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
   }
   if (env->ExceptionCheck()) return nullptr;
 
+  const char* cprefix = nullptr;
+  if (prefix != nullptr) {
+    cprefix = env->GetStringUTFChars(prefix, nullptr);
+  }
   TF_Status* status = TF_NewStatus();
-  TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+  TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+                            status, dy.get());
+  if (prefix != nullptr) {
+    env->ReleaseStringUTFChars(prefix, cprefix);
+  }
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
     return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
 /*
  * Class:     org_tensorflow_Graph
  * Method:    name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
  */
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
-    jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
-    jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+    JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+    jintArray, jlongArray, jintArray);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
 
 import java.util.HashSet;
 import java.util.Iterator;
-
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -180,8 +179,8 @@ public class GraphTest {
       Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
       Output<Float> y0 = TestUtil.square(g, "y0", x);
       Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
-      Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+      Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
       assertNotNull(grad);
       assertEquals(1, grad.length);
       assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +211,7 @@ public class GraphTest {
       assertEquals(1, grad0.length);
       assertEquals(DataType.FLOAT, grad0[0].dataType());
 
-      Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
       assertNotNull(grad1);
       assertEquals(1, grad1.length);
       assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -228,6 +227,33 @@ public class GraphTest {
       }
     }
   }
+
+  @Test
+  public void validateGradientsNames() {
+    try (Graph g = new Graph()) {
+
+      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+      Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+      assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
+      assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
+
+      Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+      assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
+
+      Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+      assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
+
+      try {
+        g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+      } catch (IllegalArgumentException e) {
+        // expected exception
+      }
+    }
+  }
 
   private static Output<?>[] toArray(Output<?>... outputs) {
     return outputs;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
   public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
       implements AutoCloseable {
 
-    AutoCloseableList(Collection<? extends E> c) {
+    public AutoCloseableList(Collection<? extends E> c) {
       super(c);
     }
 
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..3f49790b29
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+  @Test
+  public void createGradients() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+      assertNotNull(grads);
+      assertNotNull(grads.dy());
+      assertEquals(2, grads.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(
+                  sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
+        assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+        assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void createGradientsWithSum() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+      assertNotNull(grads);
+      assertNotNull(grads.dy());
+      assertEquals(1, grads.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
+        assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void createGradientsWithInitialValues() {
+    try (Graph g = new Graph();
+        Session sess = new Session(g)) {
+      Scope scope = new Scope(g);
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+      Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+      Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+      Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+      assertNotNull(grads1);
+      assertNotNull(grads1.dy());
+      assertEquals(1, grads1.dy().size());
+
+      try (Tensor<Float> c = Tensors.create(3.0f);
+          TestUtil.AutoCloseableList<Tensor<?>> outputs =
+              new TestUtil.AutoCloseableList<>(
+                  sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
+        assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void validateGradientsNames() {
+    try (Graph g = new Graph()) {
+      Scope scope = new Scope(g).withSubScope("sub");
+
+      Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+      Output<Float> y = TestUtil.square(g, "y", x);
+
+      Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+      assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
+
+      Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+      assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
+    }
+  }
+}
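At the op level, `Gradients.create` now derives its prefix from the scope via `scope.makeOpName("Gradients")`, so `withSubScope` and `withName` decide where the gradient subgraph lands, exactly as `GradientsTest.validateGradientsNames` above asserts. A short usage sketch under the same assumptions as the earlier examples (illustrative class name, raw `OperationBuilder` calls standing in for the test helpers):

```java
import java.util.Arrays;
import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Gradients;

public final class ScopedGradientsSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Scope scope = new Scope(g).withSubScope("sub");

      Output<Float> x =
          g.opBuilder("Placeholder", "x").setAttr("dtype", DataType.FLOAT).build().<Float>output(0);
      Output<Float> y = g.opBuilder("Square", "y").addInput(x).build().<Float>output(0);

      // Default op name: gradient nodes land under "sub/Gradients/".
      Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
      System.out.println(grad0.dy(0).op().name());

      // withName overrides it: gradient nodes land under "sub/MyGradients/".
      Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
      System.out.println(grad1.dy(0).op().name());
    }
  }
}
```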