 tensorflow/c/c_api.cc                                                   |  41
 tensorflow/c/c_api.h                                                    |  34
 tensorflow/c/c_api_test.cc                                              |  84
 tensorflow/java/BUILD                                                   |  13
 tensorflow/java/src/main/java/org/tensorflow/Graph.java                 |  64
 tensorflow/java/src/main/java/org/tensorflow/op/Scope.java              |  13
 tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java     |  48
 tensorflow/java/src/main/native/graph_jni.cc                            |  21
 tensorflow/java/src/main/native/graph_jni.h                             |   8
 tensorflow/java/src/test/java/org/tensorflow/GraphTest.java             |  34
 tensorflow/java/src/test/java/org/tensorflow/TestUtil.java              |   2
 tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java | 131
 12 files changed, 415 insertions(+), 78 deletions(-)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 10bc8cdbee..19ccb6e71d 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -2389,6 +2390,12 @@ void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
TF_Output* dx, TF_Status* status, TF_Output* dy) {
+ TF_AddGradientsWithPrefix(g, nullptr, y, ny, x, nx, dx, status, dy);
+}
+
+void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix, TF_Output* y,
+ int ny, TF_Output* x, int nx, TF_Output* dx,
+ TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Adding gradients is not supported in Android. File a bug at "
@@ -2405,9 +2412,29 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
const int first_new_node_id = g->graph.num_node_ids();
+ string prefix_cmp;
+ const char* child_scope_name;
+ if (prefix == nullptr) {
+ child_scope_name = "gradients";
+ } else {
+ prefix_cmp = string(prefix) + "/";
+ // The operation should fail if the provided name prefix has already been
+ // used in this graph
+ for (const auto& pair : g->name_map) {
+ const string& name = pair.first;
+ if (name.compare(prefix) == 0 ||
+ tensorflow::str_util::StartsWith(name, prefix_cmp)) {
+ status->status = InvalidArgument(
+ "prefix [", prefix,
+ "] conflicts with existing node in the graph named [", name, "]");
+ return;
+ }
+ }
+ child_scope_name = prefix;
+ }
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
- .NewSubScope("gradients");
+ .NewSubScope(child_scope_name);
if (dx != nullptr) {
std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
@@ -2422,6 +2449,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
+
+ // Adding the gradients to the graph can alter the prefix to prevent
+ // name collisions only if this prefix has not been provided explicitly
+ // by the user. If it was provided, assert that it remained intact.
+ if (prefix != nullptr &&
+ !tensorflow::str_util::StartsWith(n->name(), prefix_cmp)) {
+ status->status = tensorflow::errors::Internal(
+          "BUG: The gradients prefix has been unexpectedly altered when "
+ "adding the nodes to the graph. This is a bug. Please file an "
+ "issue at https://github.com/tensorflow/tensorflow/issues.");
+ return;
+ }
// We have a convoluted scheme here: Using the C++ graph construction API
// to add potentially many nodes to the graph without running the checks
// (such as uniqueness of the names of nodes) we run with other functions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 7e97351c8a..850f6ecd63 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1131,6 +1131,7 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+//
// `dx` are used as initial gradients (which represent the symbolic partial
// derivatives of some loss function `L` w.r.t. `y`).
// `dx` must be nullptr or have size `ny`.
@@ -1139,6 +1140,12 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
// The partial derivatives are returned in `dy`. `dy` should be allocated to
// size `nx`.
//
+// Gradient nodes are automatically named under the "gradients/" prefix. To
+// guarantee name uniqueness, subsequent calls to the same graph will
+// append an incremental tag to the prefix: "gradients_1/", "gradients_2/", ...
+// See TF_AddGradientsWithPrefix, which provides a means to specify a custom
+// name prefix for operations added to a graph to compute the gradients.
+//
// WARNING: This function does not yet support all the gradients that python
// supports. See
// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
@@ -1147,6 +1154,33 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
+// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s,
+// i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
+// This is a variant of TF_AddGradients that allows the caller to pass a custom
+// name prefix to the operations added to a graph to compute the gradients.
+//
+// `dx` are used as initial gradients (which represent the symbolic partial
+// derivatives of some loss function `L` w.r.t. `y`).
+// `dx` must be nullptr or have size `ny`.
+// If `dx` is nullptr, the implementation will use dx of `OnesLike` for all
+// shapes in `y`.
+// The partial derivatives are returned in `dy`. `dy` should be allocated to
+// size `nx`.
+// `prefix` names the scope into which all gradient operations are added.
+// `prefix` must be unique within the provided graph, otherwise this operation
+// will fail. If `prefix` is nullptr, the default prefixing behaviour takes
+// place; see TF_AddGradients for more details.
+//
+// WARNING: This function does not yet support all the gradients that python
+// supports. See
+// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
+// for instructions on how to add more C++ gradients.
+TF_CAPI_EXPORT void TF_AddGradientsWithPrefix(TF_Graph* g, const char* prefix,
+ TF_Output* y, int ny,
+ TF_Output* x, int nx,
+ TF_Output* dx, TF_Status* status,
+ TF_Output* dy);
+
// Create a TF_Function from a TF_Graph
//
// Params:
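
The naming rules documented in the header comments above can be illustrated end to end. The sketch below is hypothetical and not part of this change; it goes through the Java binding added further down in this commit, which forwards to TF_AddGradientsWithPrefix, and the class and variable names are made up for illustration.

    import org.tensorflow.DataType;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;

    public class GradientPrefixNaming {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          // Build a tiny graph: y = x^2.
          Output<Float> x = g.opBuilder("Placeholder", "x")
              .setAttr("dtype", DataType.FLOAT).build().output(0);
          Output<Float> y = g.opBuilder("Square", "y")
              .addInput(x).build().output(0);

          // Default prefix: the first call lands under "gradients/", the next under "gradients_1/".
          Output<?>[] d0 = g.addGradients(null, new Output<?>[] {y}, new Output<?>[] {x}, null);
          Output<?>[] d1 = g.addGradients(null, new Output<?>[] {y}, new Output<?>[] {x}, null);

          // Explicit prefix: must not collide with any existing node name in the graph.
          Output<?>[] d2 = g.addGradients("my_grads", new Output<?>[] {y}, new Output<?>[] {x}, null);

          System.out.println(d0[0].op().name()); // expected to start with "gradients/"
          System.out.println(d1[0].op().name()); // expected to start with "gradients_1/"
          System.out.println(d2[0].op().name()); // expected to start with "my_grads/"
        }
      }
    }
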
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index e674b1623c..aa2a537f03 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1483,8 +1483,8 @@ class CApiGradientsTest : public ::testing::Test {
BuildSuccessGraph(inputs, outputs);
BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
- AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
-
+ AddGradients(grad_inputs_provided, nullptr, inputs, 2, outputs, 1,
+ grad_outputs);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Compare that the graphs match.
@@ -1505,7 +1505,8 @@ class CApiGradientsTest : public ::testing::Test {
BuildErrorGraph(inputs, outputs);
- AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+ AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
+ grad_outputs);
string expected_msg =
"No gradient defined for op: TestOpWithNoGradient. Please see "
@@ -1549,19 +1550,20 @@ class CApiGradientsTest : public ::testing::Test {
EXPECT_EQ(*a_data, *b_data);
}
- void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
- TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
+ void AddGradients(bool grad_inputs_provided, const char* prefix,
+ TF_Output* inputs, int ninputs, TF_Output* outputs,
+ int noutputs, TF_Output* grad_outputs) {
if (grad_inputs_provided) {
TF_Output grad_inputs[1];
const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
TF_Operation* grad_inputs_op =
FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
grad_inputs[0] = TF_Output{grad_inputs_op, 0};
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
- s_, grad_outputs);
+ TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+ ninputs, grad_inputs, s_, grad_outputs);
} else {
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
- grad_outputs);
+ TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
+ ninputs, nullptr, s_, grad_outputs);
}
}
@@ -1706,6 +1708,20 @@ class CApiGradientsTest : public ::testing::Test {
return op;
}
+ void BuildGraphAndAddGradientsWithPrefixes(const char* prefix1,
+ const char* prefix2 = nullptr) {
+ TF_Output inputs[2];
+ TF_Output outputs[1];
+ TF_Output grad_outputs[2];
+
+ BuildSuccessGraph(inputs, outputs);
+
+ AddGradients(false, prefix1, inputs, 2, outputs, 1, grad_outputs);
+ if (prefix2 != nullptr) {
+ AddGradients(false, prefix2, inputs, 2, outputs, 1, grad_outputs);
+ }
+ }
+
TF_Status* s_;
TF_Graph* graph_;
TF_Graph* expected_graph_;
@@ -1725,6 +1741,56 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixIsOk) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithDistinctPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInSameScope) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope/gradients_1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsInDifferentScopes) {
+ BuildGraphAndAddGradientsWithPrefixes("scope/gradients", "scope_1/gradients");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsSubScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/sub");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_PrefixMatchesExistingNodeName) {
+ BuildGraphAndAddGradientsWithPrefixes("Const_0");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_TwoGradientsWithIdenticalPrefixes) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsMatchingNodeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients", "gradients/MatMul");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_1stGradientsMatchingNodeOf2nd) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/MatMul", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
+TEST_F(CApiGradientsTest, GradientsPrefix_2ndGradientsAsParentScopeOf1st) {
+ BuildGraphAndAddGradientsWithPrefixes("gradients/sub", "gradients");
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
+}
+
void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
ASSERT_TRUE(t != nullptr);
ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 73e210fae0..7ceba3903d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -292,6 +292,19 @@ tf_java_test(
],
)
+tf_java_test(
+ name = "GradientsTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.GradientsTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "processor_test_resources",
srcs = glob([
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
- * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
- * <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
- * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
- * <p>
- * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
- * shapes in {@code y}.
- *
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ *
+ * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+ * {@code y}.
+ *
+ * <p>If {@code dx} is null, the implementation will use dx of {@link
+ * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+ *
+ * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+ * gradients. It must be unique within the provided graph or the operation will fail.
+ *
+ * <p>If {@code prefix} is null, then one will be chosen automatically.
+ *
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+    // Gradient outputs are returned in two contiguous arrays concatenated into one. The first
+ // holds the native handles of the gradient operations while the second holds the index of
+ // their output e.g. given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(
+ ref.nativeHandle(),
+ prefix,
+ yHandles,
+ yIndices,
+ xHandles,
+ xIndices,
+ dxHandles,
+ dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code dy/dx_1, dy/dx_2...}
- * <p>
+ * <p>
   * This is a simplified version of {@link #addGradients(Output[], Output[], Output[])} where {@code y} is
- * a single output and {@code dx} is null.
- *
+ * a single output, {@code dx} is null and {@code prefix} is null.
+ *
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[] {y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
- long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+ private static native long[] addGradients(
+ long handle,
+ String prefix,
+ long[] inputHandles,
+ int[] inputIndices,
+ long[] outputHandles,
+ int[] outputIndices,
+ long[] gradInputHandles,
+ int[] gradInputIndices);
static {
TensorFlow.init();
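
A hedged sketch of how the new addGradients(prefix, y, x, dx) overload composes: dx threads the result of one call into the next, and reusing an explicit prefix is rejected by the native layer as an IllegalArgumentException, the behaviour exercised by GraphTest below. All class and variable names here are illustrative.

    import org.tensorflow.DataType;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;

    public class ChainedGradientsSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          // y0 = x^2, y1 = y0^2
          Output<Float> x = g.opBuilder("Placeholder", "x")
              .setAttr("dtype", DataType.FLOAT).build().output(0);
          Output<Float> y0 = g.opBuilder("Square", "y0").addInput(x).build().output(0);
          Output<Float> y1 = g.opBuilder("Square", "y1").addInput(y0).build().output(0);

          // d(y1)/d(y0), grouped under the explicit prefix "outer".
          Output<?>[] dy0 = g.addGradients("outer", new Output<?>[] {y1}, new Output<?>[] {y0}, null);

          // Feed that result back in as dx to obtain d(y1)/dx, under a second prefix.
          Output<?>[] dydx = g.addGradients("inner", new Output<?>[] {y0}, new Output<?>[] {x}, dy0);
          System.out.println(dydx[0].op().name()); // expected to start with "inner/"

          // A prefix that already names nodes in the graph is rejected.
          try {
            g.addGradients("outer", new Output<?>[] {y1}, new Output<?>[] {y0}, null);
          } catch (IllegalArgumentException e) {
            System.out.println("prefix collision: " + e.getMessage());
          }
        }
      }
    }
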
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,17 +135,8 @@ public final class Scope {
* }</pre>
*
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
- * set of related operations to the graph by calling other operator building code) you should also
- * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a
- * meaningful name.
- *
- * <pre>{@code
- * public static Stddev create(Scope scope, ...) {
- * // group sub-operations under a common name
- * Scope group = scope.withSubScope("stddev");
- * ... Sqrt.create(group, Mean.create(group, ...))
- * }
- * }</pre>
+ * set of related operations to the graph by calling other operator building code), the provided
+ * name will act as a subscope to all underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
-
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
@@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* Optional attributes for {@link Gradients}
*/
public static class Options {
-
+
/**
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return this option builder
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public Options dx(Iterable<? extends Operand<?>> dx) {
this.dx = dx;
return this;
}
-
- private Iterable<Operand<?>> dx;
-
+
+ private Iterable<? extends Operand<?>> dx;
+
private Options() {
}
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
+ *
* @param scope current graph scope
* @param y outputs of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param options carries optional attributes values
* @return a new instance of {@code Gradients}
*/
- public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope,
+ Iterable<? extends Operand<?>> y,
+ Iterable<? extends Operand<?>> x,
+ Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new Gradients(Arrays.asList(gradOutputs));
+ Output<?>[] dy =
+ scope
+ .graph()
+ .addGradients(
+ scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(dy));
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
- * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
- * a single output.
- *
+ *
+ * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+ * {@code y} is a single output.
+ *
* @param scope current graph scope
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
@@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @return a new instance of {@code Gradients}
*/
@SuppressWarnings({"unchecked", "rawtypes"})
- public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
return create(scope, (Iterable) Arrays.asList(y), x, options);
}
@@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return builder to add more options to this operation
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public static Options dx(Iterable<? extends Operand<?>> dx) {
return new Options().dx(dx);
}
@@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> {
public List<Output<?>> dy() {
return dy;
}
-
+
/**
* Returns a symbolic handle to one of the gradient operation output
- * <p>
- * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
* this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
- * gradients.<Integer>dy(0)}
+ * gradients.<Float>dy(0)}
*
* @param <T> The expected element type of the tensors produced by this output.
* @param index The index of the output among the gradients added by this operation
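
At the Ops level, the prefix now comes from the scope: Gradients.create passes scope.makeOpName("Gradients") to Graph.addGradients, so gradient nodes land under "Gradients/" by default or under whatever name the caller sets with withName. The following is a hypothetical, self-contained sketch; raw op builders stand in for the TestUtil helpers used in the tests.

    import java.util.Arrays;
    import org.tensorflow.DataType;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;
    import org.tensorflow.op.Scope;
    import org.tensorflow.op.core.Gradients;

    public class GradientsOpSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          Scope scope = new Scope(g);
          // y0 = x^2, y1 = y0^2
          Output<Float> x = g.opBuilder("Placeholder", "x")
              .setAttr("dtype", DataType.FLOAT).build().output(0);
          Output<Float> y0 = g.opBuilder("Square", "y0").addInput(x).build().output(0);
          Output<Float> y1 = g.opBuilder("Square", "y1").addInput(y0).build().output(0);

          // Default naming: gradient ops are grouped under "Gradients/" (then "Gradients_1/", ...).
          Gradients outer = Gradients.create(scope, y1, Arrays.asList(y0));

          // Custom name plus explicit initial gradients through the dx option.
          Gradients inner = Gradients.create(
              scope.withName("MyGradients"), y0, Arrays.asList(x), Gradients.dx(outer.dy()));

          System.out.println(inner.dy(0).op().name()); // expected to start with "MyGradients/"
        }
      }
    }
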
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
return ret;
}
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jlongArray y_handles, jintArray y_indices,
- jlongArray x_handles, jintArray x_indices,
- jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+ jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+ jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return nullptr;
@@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
}
if (env->ExceptionCheck()) return nullptr;
+ const char* cprefix = nullptr;
+ if (prefix != nullptr) {
+ cprefix = env->GetStringUTFChars(prefix, nullptr);
+ }
TF_Status* status = TF_NewStatus();
- TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+ TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+ status, dy.get());
+ if (prefix != nullptr) {
+ env->ReleaseStringUTFChars(prefix, cprefix);
+ }
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
* Class: org_tensorflow_Graph
* Method: name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
- jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+ jintArray, jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -180,8 +179,8 @@ public class GraphTest {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
- Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+ Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +211,7 @@ public class GraphTest {
assertEquals(1, grad0.length);
assertEquals(DataType.FLOAT, grad0[0].dataType());
- Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
assertNotNull(grad1);
assertEquals(1, grad1.length);
assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -228,6 +227,33 @@ public class GraphTest {
}
}
}
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+ Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
+
+ Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
+
+ Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
+
+ try {
+ g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+ }
+ }
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
+ public AutoCloseableList(Collection<? extends E> c) {
super(c);
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..3f49790b29
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+ @Test
+ public void createGradients() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(2, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithSum() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(1, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
+ assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithInitialValues() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+ Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+ assertNotNull(grads1);
+ assertNotNull(grads1.dy());
+ assertEquals(1, grads1.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+ Scope scope = new Scope(g).withSubScope("sub");
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y = TestUtil.square(g, "y", x);
+
+ Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+ assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
+
+ Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+ assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
+ }
+ }
+}