aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-25 22:12:24 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-07-25 21:10:29 -0400
commit2b303fddafec6b96a6868aaa76f55cc392b96586 (patch)
tree8b1da320c69ba5239f8bdd37bfac95cd02704d65
parentb24037513f12a5812a21b7ea92ff904ee9ea6cd8 (diff)
Add scope name to TF_AddGradients
-rw-r--r--tensorflow/c/c_api.cc7
-rw-r--r--tensorflow/c/c_api.h6
-rw-r--r--tensorflow/c/c_api_test.cc44
-rw-r--r--tensorflow/c/while_loop_test.cc4
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java17
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java4
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java4
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc14
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h6
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java19
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java17
12 files changed, 116 insertions, 39 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 10bc8cdbee..96653154e5 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2387,8 +2387,9 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
-void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
- TF_Output* dx, TF_Status* status, TF_Output* dy) {
+void TF_AddGradients(TF_Graph* g, const char* scope_name, TF_Output* y,
+ int ny, TF_Output* x, int nx, TF_Output* dx,
+ TF_Status* status, TF_Output* dy) {
#ifdef __ANDROID__
status->status = tensorflow::errors::Unimplemented(
"Adding gradients is not supported in Android. File a bug at "
@@ -2407,7 +2408,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
tensorflow::Scope scope =
NewInternalScope(&g->graph, &status->status, &g->refiner)
- .NewSubScope("gradients");
+ .NewSubScope(scope_name != nullptr ? scope_name : "gradients");
if (dx != nullptr) {
std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c8ae6f2dd1..e896f68ce0 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1138,12 +1138,16 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
// shapes in `y`.
// The partial derivatives are returned in `dy`. `dy` should be allocated to
// size `nx`.
+// `scope_name` names the scope (or sub-scope) into which all gradient
+// operations are added. If `scope_name` is nullptr, "gradients" is used by
+// default.
//
// WARNING: This function does not yet support all the gradients that python
// supports. See
// https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
// for instructions on how to add more C++ gradients.
-TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
+TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, const char* scope_name,
+ TF_Output* y, int ny,
TF_Output* x, int nx, TF_Output* dx,
TF_Status* status, TF_Output* dy);
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index e674b1623c..2fe9e91583 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1474,16 +1474,18 @@ class CApiGradientsTest : public ::testing::Test {
TF_DeleteStatus(s_);
}
- void TestGradientsSuccess(bool grad_inputs_provided) {
+ void TestGradientsSuccess(bool grad_inputs_provided,
+ const char* scope_name = nullptr) {
TF_Output inputs[2];
TF_Output outputs[1];
TF_Output grad_outputs[2];
TF_Output expected_grad_outputs[2];
BuildSuccessGraph(inputs, outputs);
- BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
+ BuildExpectedGraph(grad_inputs_provided, scope_name, expected_grad_outputs);
- AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
+ AddGradients(grad_inputs_provided, scope_name, inputs, 2, outputs, 1,
+ grad_outputs);
EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
@@ -1505,7 +1507,8 @@ class CApiGradientsTest : public ::testing::Test {
BuildErrorGraph(inputs, outputs);
- AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+ AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
+ grad_outputs);
string expected_msg =
"No gradient defined for op: TestOpWithNoGradient. Please see "
@@ -1549,19 +1552,20 @@ class CApiGradientsTest : public ::testing::Test {
EXPECT_EQ(*a_data, *b_data);
}
- void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
- TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
+ void AddGradients(bool grad_inputs_provided, const char* scope_name,
+ TF_Output* inputs, int ninputs, TF_Output* outputs,
+ int noutputs, TF_Output* grad_outputs) {
if (grad_inputs_provided) {
TF_Output grad_inputs[1];
const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
TF_Operation* grad_inputs_op =
FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
grad_inputs[0] = TF_Output{grad_inputs_op, 0};
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
- s_, grad_outputs);
+ TF_AddGradients(graph_, scope_name, outputs, noutputs, inputs, ninputs,
+ grad_inputs, s_, grad_outputs);
} else {
- TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
- grad_outputs);
+ TF_AddGradients(graph_, scope_name, outputs, noutputs, inputs, ninputs,
+ nullptr, s_, grad_outputs);
}
}
@@ -1600,6 +1604,7 @@ class CApiGradientsTest : public ::testing::Test {
}
void BuildExpectedGraph(bool grad_inputs_provided,
+ const char* grad_scope_name,
TF_Output* expected_grad_outputs) {
// The expected graph looks like this if grad_inputs_provided.
// If grad_inputs_provided is false, Const_0 will be a OnesLike op.
@@ -1628,6 +1633,10 @@ class CApiGradientsTest : public ::testing::Test {
//
const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
+ const char* grad_prefix = grad_scope_name;
+ if (grad_scope_name == nullptr) {
+ grad_prefix = "gradients";
+ }
TF_Operation* const0 =
FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
TF_Operation* const1 =
@@ -1640,13 +1649,14 @@ class CApiGradientsTest : public ::testing::Test {
const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
} else {
- const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
+ const3 = OnesLike(expected_graph_, s_, matmul,
+ strings::StrCat(grad_prefix, "/OnesLike").c_str());
}
TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
- "gradients/MatMul", false, true);
+ strings::StrCat(grad_prefix, "/MatMul").c_str(), false, true);
TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
- "gradients/MatMul_1", true, false);
+ strings::StrCat(grad_prefix, "/MatMul_1").c_str(), true, false);
expected_grad_outputs[0] = {matmul1, 0};
expected_grad_outputs[1] = {matmul2, 0};
}
@@ -1717,6 +1727,10 @@ TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
TestGradientsSuccess(false);
}
+TEST_F(CApiGradientsTest, Gradients_NoGradInputsWithScopeName) {
+ TestGradientsSuccess(false, "gradscope");
+}
+
TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
TestGradientsError(true);
}
@@ -1743,11 +1757,11 @@ TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
TF_Output outputs[1] = {{xy, 0}};
TF_Output inputs[1] = {{x, 0}};
- TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
+ TF_AddGradients(graph_, nullptr, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
inputs[0] = {y, 0};
- TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
+ TF_AddGradients(graph_, nullptr, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
TF_SessionOptions* opts = TF_NewSessionOptions();
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index d2d887f32c..12225fd1cb 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -431,8 +431,8 @@ TEST_F(CApiWhileLoopTest, Gradients) {
// Create backprop graph
TF_Output grad_output;
- TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
- nullptr, s_, &grad_output);
+ TF_AddGradients(graph_, nullptr, outputs_.data(), outputs_.size(),
+ inputs_.data(), 1, nullptr, s_, &grad_output);
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
// Run gradient
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..f2bd3e99a5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -153,12 +153,13 @@ public final class Graph implements AutoCloseable {
* If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
* shapes in {@code y}.
*
+   * @param scopeName name of the subscope into which gradient operations are added. If null, defaults to "gradients".
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String scopeName, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +186,12 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+    // Gradient outputs are returned in two contiguous arrays concatenated into one. The first holds the native
+    // handles of the gradient operations while the second holds the index of their output e.g. given
+    // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -209,14 +210,14 @@ public final class Graph implements AutoCloseable {
* i.e., {@code dy/dx_1, dy/dx_2...}
* <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[])} where {@code y} is
- * a single output and {@code dx} is null.
+ * a single output, {@code dx} is null and {@code scopeName} is null.
*
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[]{y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,7 +331,7 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices,
long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
static {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
index 2e84cac1ac..92e05d2d6d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
@@ -56,6 +56,10 @@ final class NameScope {
String actualName = (opName != null) ? opName : name;
return fullyQualify(makeUnique(actualName));
}
+
+ String prefix() {
+ return opPrefix;
+ }
/**
* Create a new, root-level namescope.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..d1ab44c3b2 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -154,6 +154,19 @@ public final class Scope {
public String makeOpName(String defaultName) {
return nameScope.makeOpName(defaultName);
}
+
+ /**
+ * The name prefix of this scope
+ * <p>
+   * This value is the combination of the name of this scope and all of its parents, separated by a '/', e.g.
+ * <pre>{@code
+ * Scope scope = new Scope(graph);
+ * assertEquals(scope.withSubScope("sub1").withSubScope("sub2").prefix(), "sub1/sub2");
+ * }</pre>
+ */
+ public String prefix() {
+ return nameScope.prefix();
+ }
private Scope(Graph graph, NameScope nameScope) {
this.graph = graph;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..d88dc3ba46 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -88,8 +88,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new Gradients(Arrays.asList(gradOutputs));
+ Output<?>[] dy = scope.graph().addGradients(scope.prefix(), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(dy));
}
/**
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..a9b2ef6494 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -135,7 +135,7 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
JNIEXPORT jlongArray JNICALL
Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jlongArray y_handles, jintArray y_indices,
+ jstring scope_name, jlongArray y_handles, jintArray y_indices,
jlongArray x_handles, jintArray x_indices,
jlongArray dx_handles, jintArray dx_indices) {
@@ -163,9 +163,17 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
}
if (env->ExceptionCheck()) return nullptr;
+ jboolean is_copy;
+ const char* cscope_name = nullptr;
+ if (scope_name != nullptr) {
+ cscope_name = env->GetStringUTFChars(scope_name, &is_copy);
+ }
TF_Status* status = TF_NewStatus();
- TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+ TF_AddGradients(g, cscope_name, y.get(), ny, x.get(), nx, dx.get(), status,
+ dy.get());
+ if (scope_name != nullptr) {
+ env->ReleaseStringUTFChars(scope_name, cscope_name);
+ }
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..e483bf953b 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
* Class: org_tensorflow_Graph
* Method: name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
- jintArray);
+ jclass, jlong, jstring, jlongArray, jintArray, jlongArray, jintArray,
+ jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..c02336aebe 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -181,7 +181,7 @@ public class GraphTest {
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
- Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +212,7 @@ public class GraphTest {
assertEquals(1, grad0.length);
assertEquals(DataType.FLOAT, grad0[0].dataType());
- Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
assertNotNull(grad1);
assertEquals(1, grad1.length);
assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -229,6 +229,21 @@ public class GraphTest {
}
}
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+ Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+ Output<?>[] grad1 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad1[0].op().name().startsWith("more_gradients/"));
+ }
+ }
+
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 125de73554..2057007499 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -17,10 +17,12 @@ package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.HashMap;
import java.util.Map;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -181,6 +183,21 @@ public class ScopeTest {
assertEquals(21704, result.intValue());
}
}
+
+ @Test
+ public void prefix() {
+ try (Graph g = new Graph()) {
+ Scope s = new Scope(g);
+ assertNotNull(s.prefix());
+ assertTrue(s.prefix().isEmpty());
+
+ Scope sub1 = s.withSubScope("sub1");
+ assertEquals("sub1", sub1.prefix());
+
+ Scope sub2 = sub1.withSubScope("sub2");
+ assertEquals("sub1/sub2", sub2.prefix());
+ }
+ }
// "handwritten" sample operator classes
private static final class Const<T> {