author     karl@kubx.ca <karl@kubx.ca>   2018-06-25 22:12:24 -0400
committer  karl@kubx.ca <karl@kubx.ca>   2018-07-25 21:10:29 -0400
commit     2b303fddafec6b96a6868aaa76f55cc392b96586 (patch)
tree       8b1da320c69ba5239f8bdd37bfac95cd02704d65
parent     b24037513f12a5812a21b7ea92ff904ee9ea6cd8 (diff)
Add scope name to TF_AddGradients
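
With this change the gradient ops added by TF_AddGradients / Graph.addGradients can be
placed under a caller-chosen name scope instead of always landing under the hard-coded
"gradients" sub-scope; passing a null scope name keeps the old behaviour. A rough usage
sketch of the new Java overload follows (the example class, the "Placeholder"/"Square"
ops and the "my_gradients" scope name are illustrative only, not part of this change):

    import org.tensorflow.DataType;
    import org.tensorflow.Graph;
    import org.tensorflow.Output;

    public final class ScopedGradientsExample {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          // Build a tiny graph: y = Square(x).
          Output<Float> x = g.opBuilder("Placeholder", "x")
              .setAttr("dtype", DataType.FLOAT)
              .build()
              .output(0);
          Output<Float> y = g.opBuilder("Square", "y")
              .addInput(x)
              .build()
              .output(0);

          // New overload: the first argument names the sub-scope that will hold
          // the gradient ops; passing null keeps the previous "gradients" default.
          Output<?>[] dy = g.addGradients("my_gradients", new Output<?>[] {y},
                                          new Output<?>[] {x}, null);

          // Gradient op names are now prefixed with the requested scope,
          // e.g. "my_gradients/...".
          System.out.println(dy[0].op().name());
        }
      }
    }
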
-rw-r--r--  tensorflow/c/c_api.cc                                                |  7
-rw-r--r--  tensorflow/c/c_api.h                                                 |  6
-rw-r--r--  tensorflow/c/c_api_test.cc                                           | 44
-rw-r--r--  tensorflow/c/while_loop_test.cc                                      |  4
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/Graph.java              | 17
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java       |  4
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/Scope.java           | 13
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java  |  4
-rw-r--r--  tensorflow/java/src/main/native/graph_jni.cc                         | 14
-rw-r--r--  tensorflow/java/src/main/native/graph_jni.h                          |  6
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/GraphTest.java          | 19
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java       | 17
12 files changed, 116 insertions, 39 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 10bc8cdbee..96653154e5 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -2387,8 +2387,9 @@ void TF_FinishWhile(const TF_WhileParams* params, TF_Status* status,
 
 void TF_AbortWhile(const TF_WhileParams* params) { FreeWhileResources(params); }
 
-void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
-                     TF_Output* dx, TF_Status* status, TF_Output* dy) {
+void TF_AddGradients(TF_Graph* g, const char* scope_name, TF_Output* y,
+                     int ny, TF_Output* x, int nx, TF_Output* dx,
+                     TF_Status* status, TF_Output* dy) {
 #ifdef __ANDROID__
   status->status = tensorflow::errors::Unimplemented(
       "Adding gradients is not supported in Android. File a bug at "
@@ -2407,7 +2408,7 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
 
     tensorflow::Scope scope =
         NewInternalScope(&g->graph, &status->status, &g->refiner)
-            .NewSubScope("gradients");
+            .NewSubScope(scope_name != nullptr ? scope_name : "gradients");
 
     if (dx != nullptr) {
       std::vector<tensorflow::Output> dx_arg = OutputsFromTFOutputs(dx, ny);
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c8ae6f2dd1..e896f68ce0 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1138,12 +1138,16 @@ TF_CAPI_EXPORT extern void TF_AbortWhile(const TF_WhileParams* params);
 // shapes in `y`.
 // The partial derivatives are returned in `dy`. `dy` should be allocated to
 // size `nx`.
+// `scope_name` names the scope (or sub-scope) into which all gradients
+// operations are added. If `scope_name` is nullptr, "gradients" is used by
+// default.
 //
 // WARNING: This function does not yet support all the gradients that python
 // supports. See
 // https://www.tensorflow.org/code/tensorflow/cc/gradients/README.md
 // for instructions on how to add C++ more gradients.
-TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny,
+TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, const char* scope_name,
+                                    TF_Output* y, int ny,
                                     TF_Output* x, int nx, TF_Output* dx,
                                     TF_Status* status, TF_Output* dy);
 
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index e674b1623c..2fe9e91583 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1474,16 +1474,18 @@ class CApiGradientsTest : public ::testing::Test {
     TF_DeleteStatus(s_);
   }
 
-  void TestGradientsSuccess(bool grad_inputs_provided) {
+  void TestGradientsSuccess(bool grad_inputs_provided,
+                            const char* scope_name = nullptr) {
     TF_Output inputs[2];
     TF_Output outputs[1];
     TF_Output grad_outputs[2];
     TF_Output expected_grad_outputs[2];
 
     BuildSuccessGraph(inputs, outputs);
-    BuildExpectedGraph(grad_inputs_provided, expected_grad_outputs);
+    BuildExpectedGraph(grad_inputs_provided, scope_name, expected_grad_outputs);
 
-    AddGradients(grad_inputs_provided, inputs, 2, outputs, 1, grad_outputs);
+    AddGradients(grad_inputs_provided, scope_name, inputs, 2, outputs, 1,
+                 grad_outputs);
 
     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
 
@@ -1505,7 +1507,8 @@ class CApiGradientsTest : public ::testing::Test {
 
     BuildErrorGraph(inputs, outputs);
 
-    AddGradients(grad_inputs_provided, inputs, 1, outputs, 1, grad_outputs);
+    AddGradients(grad_inputs_provided, nullptr, inputs, 1, outputs, 1,
+                 grad_outputs);
 
     string expected_msg =
         "No gradient defined for op: TestOpWithNoGradient. Please see "
@@ -1549,19 +1552,20 @@ class CApiGradientsTest : public ::testing::Test {
     EXPECT_EQ(*a_data, *b_data);
   }
 
-  void AddGradients(bool grad_inputs_provided, TF_Output* inputs, int ninputs,
-                    TF_Output* outputs, int noutputs, TF_Output* grad_outputs) {
+  void AddGradients(bool grad_inputs_provided, const char* scope_name,
+                    TF_Output* inputs, int ninputs, TF_Output* outputs,
+                    int noutputs, TF_Output* grad_outputs) {
     if (grad_inputs_provided) {
       TF_Output grad_inputs[1];
       const float grad_inputs_val[] = {1.0, 1.0, 1.0, 1.0};
       TF_Operation* grad_inputs_op =
           FloatConst2x2(graph_, s_, grad_inputs_val, "GradInputs");
       grad_inputs[0] = TF_Output{grad_inputs_op, 0};
-      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, grad_inputs,
-                      s_, grad_outputs);
+      TF_AddGradients(graph_, scope_name, outputs, noutputs, inputs, ninputs,
+                      grad_inputs, s_, grad_outputs);
     } else {
-      TF_AddGradients(graph_, outputs, noutputs, inputs, ninputs, nullptr, s_,
-                      grad_outputs);
+      TF_AddGradients(graph_, scope_name, outputs, noutputs, inputs, ninputs,
+                      nullptr, s_, grad_outputs);
     }
   }
 
@@ -1600,6 +1604,7 @@ class CApiGradientsTest : public ::testing::Test {
   }
 
   void BuildExpectedGraph(bool grad_inputs_provided,
+                          const char* grad_scope_name,
                           TF_Output* expected_grad_outputs) {
     // The expected graph looks like this if grad_inputs_provided.
     // If grad_inputs_provided is false, Const_0 will be a OnesLike op.
@@ -1628,6 +1633,10 @@ class CApiGradientsTest : public ::testing::Test {
     //
     const float const0_val[] = {1.0, 2.0, 3.0, 4.0};
     const float const1_val[] = {1.0, 0.0, 0.0, 1.0};
+    const char* grad_prefix = grad_scope_name;
+    if (grad_scope_name == nullptr) {
+      grad_prefix = "gradients";
+    }
     TF_Operation* const0 =
         FloatConst2x2(expected_graph_, s_, const0_val, "Const_0");
     TF_Operation* const1 =
@@ -1640,13 +1649,14 @@ class CApiGradientsTest : public ::testing::Test {
       const float const3_val[] = {1.0, 1.0, 1.0, 1.0};
       const3 = FloatConst2x2(expected_graph_, s_, const3_val, "GradInputs");
     } else {
-      const3 = OnesLike(expected_graph_, s_, matmul, "gradients/OnesLike");
+      const3 = OnesLike(expected_graph_, s_, matmul,
+                        strings::StrCat(grad_prefix, "/OnesLike").c_str());
     }
 
     TF_Operation* matmul1 = MatMul(expected_graph_, s_, const3, const1,
-                                   "gradients/MatMul", false, true);
+        strings::StrCat(grad_prefix, "/MatMul").c_str(), false, true);
     TF_Operation* matmul2 = MatMul(expected_graph_, s_, const0, const3,
-                                   "gradients/MatMul_1", true, false);
+        strings::StrCat(grad_prefix, "/MatMul_1").c_str(), true, false);
     expected_grad_outputs[0] = {matmul1, 0};
     expected_grad_outputs[1] = {matmul2, 0};
   }
@@ -1717,6 +1727,10 @@ TEST_F(CApiGradientsTest, Gradients_NoGradInputs) {
   TestGradientsSuccess(false);
 }
 
+TEST_F(CApiGradientsTest, Gradients_NoGradInputsWithScopeName) {
+  TestGradientsSuccess(false, "gradscope");
+}
+
 TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_GradInputs) {
   TestGradientsError(true);
 }
@@ -1743,11 +1757,11 @@ TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
 
   TF_Output outputs[1] = {{xy, 0}};
   TF_Output inputs[1] = {{x, 0}};
-  TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
+  TF_AddGradients(graph_, nullptr, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
 
   inputs[0] = {y, 0};
-  TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
+  TF_AddGradients(graph_, nullptr, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
 
   TF_SessionOptions* opts = TF_NewSessionOptions();
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index d2d887f32c..12225fd1cb 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -431,8 +431,8 @@ TEST_F(CApiWhileLoopTest, Gradients) {
 
   // Create backprop graph
   TF_Output grad_output;
-  TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
-                  nullptr, s_, &grad_output);
+  TF_AddGradients(graph_, nullptr, outputs_.data(), outputs_.size(),
+                  inputs_.data(), 1, nullptr, s_, &grad_output);
   ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
 
   // Run gradient
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..f2bd3e99a5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -153,12 +153,13 @@ public final class Graph implements AutoCloseable {
    * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
    * shapes in {@code y}.
    *
+   * @param scopeName name of the subscope into which gradients operations are added. If null, defaults to "gradients".
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
    * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
-  public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+  public Output<?>[] addGradients(String scopeName, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
     Output<?>[] dy = new Output<?>[x.length];
     final long[] yHandles = new long[y.length];
     final int[] yIndices = new int[y.length];
@@ -185,12 +186,12 @@ public final class Graph implements AutoCloseable {
         dxIndices[i] = dx[i].index();
       }
     }
-    // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
-    // of the gradient operations while the second holds the index of their output
-    // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+    // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native
+    // handles of the gradient operations while the second holds the index of their output e.g. given
+    // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
     // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
     long[] dyHandlesAndIndices =
-        addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+        addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
     int ndy = dyHandlesAndIndices.length >> 1;
     if (ndy != dy.length) {
       throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -209,14 +210,14 @@ public final class Graph implements AutoCloseable {
    * i.e., {@code dy/dx_1, dy/dx_2...}
    * <p>
    * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
-   * a single output and {@code dx} is null.
+   * a single output, {@code dx} is null and {@code scopeName} is null.
    *
    * @param y output of the function to derive
    * @param x inputs of the function for which partial derivatives are computed
    * @return the partial derivatives {@code dy} with the size of {@code x}
    */
   public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
-    return addGradients(new Output<?>[]{y}, x, null);
+    return addGradients(null, new Output<?>[]{y}, x, null);
   }
 
   private final Object nativeHandleLock = new Object();
@@ -330,7 +331,7 @@ public final class Graph implements AutoCloseable {
 
   private static native byte[] toGraphDef(long handle);
 
-  private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+  private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices,
       long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
 
   static {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
index 2e84cac1ac..92e05d2d6d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
@@ -56,6 +56,10 @@ final class NameScope {
     String actualName = (opName != null) ? opName : name;
     return fullyQualify(makeUnique(actualName));
   }
+
+  String prefix() {
+    return opPrefix;
+  }
 
   /**
    * Create a new, root-level namescope.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..d1ab44c3b2 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -154,6 +154,19 @@ public final class Scope {
   public String makeOpName(String defaultName) {
     return nameScope.makeOpName(defaultName);
   }
+
+  /**
+   * The name prefix of this scope
+   * <p>
+   * This value is the combination of the name of this scope and all of its parents, seperated by a '/', e.g.
+   * <pre>{@code
+   * Scope scope = new Scope(graph);
+   * assertEquals(scope.withSubScope("sub1").withSubScope("sub2").prefix(), "sub1/sub2");
+   * }</pre>
+   */
+  public String prefix() {
+    return nameScope.prefix();
+  }
 
   private Scope(Graph graph, NameScope nameScope) {
     this.graph = graph;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..d88dc3ba46 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -88,8 +88,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
         }
       }
     }
-    Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
-    return new Gradients(Arrays.asList(gradOutputs));
+    Output<?>[] dy = scope.graph().addGradients(scope.prefix(), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+    return new Gradients(Arrays.asList(dy));
   }
 
   /**
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..a9b2ef6494 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -135,7 +135,7 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
 
 JNIEXPORT jlongArray JNICALL
 Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
-                                       jlongArray y_handles, jintArray y_indices,
+                                       jstring scope_name, jlongArray y_handles, jintArray y_indices,
                                        jlongArray x_handles, jintArray x_indices,
                                        jlongArray dx_handles, jintArray dx_indices) {
 
@@ -163,9 +163,17 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
   }
   if (env->ExceptionCheck()) return nullptr;
 
+  jboolean is_copy;
+  const char* cscope_name = nullptr;
+  if (scope_name != nullptr) {
+    cscope_name = env->GetStringUTFChars(scope_name, &is_copy);
+  }
   TF_Status* status = TF_NewStatus();
-  TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+  TF_AddGradients(g, cscope_name, y.get(), ny, x.get(), nx, dx.get(), status,
+                  dy.get());
+  if (scope_name != nullptr) {
+    env->ReleaseStringUTFChars(scope_name, cscope_name);
+  }
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
     return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..e483bf953b 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
 * Class:     org_tensorflow_Graph
 * Method:    name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
 */
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
-    jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
-    jintArray);
+    jclass, jlong, jstring, jlongArray, jintArray, jlongArray, jintArray,
+    jlongArray, jintArray);
 
 #ifdef __cplusplus
 }  // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..c02336aebe 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -181,7 +181,7 @@ public class GraphTest {
       Output<Float> y0 = TestUtil.square(g, "y0", x);
       Output<Float> y1 = TestUtil.square(g, "y1", y0);
 
-      Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+      Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
       assertNotNull(grad);
       assertEquals(1, grad.length);
       assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +212,7 @@ public class GraphTest {
       assertEquals(1, grad0.length);
       assertEquals(DataType.FLOAT, grad0[0].dataType());
 
-      Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+      Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
       assertNotNull(grad1);
       assertEquals(1, grad1.length);
       assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -229,6 +229,21 @@ public class GraphTest {
     }
   }
 
+  @Test
+  public void validateGradientsNames() {
+    try (Graph g = new Graph()) {
+
+      Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+      Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+      Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+      assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+      Output<?>[] grad1 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+      assertTrue(grad1[0].op().name().startsWith("more_gradients/"));
+    }
+  }
+
   private static Output<?>[] toArray(Output<?>... outputs) {
     return outputs;
   }
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 125de73554..2057007499 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -17,10 +17,12 @@ package org.tensorflow.op;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.util.HashMap;
 import java.util.Map;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
@@ -181,6 +183,21 @@ public class ScopeTest {
       assertEquals(21704, result.intValue());
     }
   }
+
+  @Test
+  public void prefix() {
+    try (Graph g = new Graph()) {
+      Scope s = new Scope(g);
+      assertNotNull(s.prefix());
+      assertTrue(s.prefix().isEmpty());
+
+      Scope sub1 = s.withSubScope("sub1");
+      assertEquals("sub1", sub1.prefix());
+
+      Scope sub2 = sub1.withSubScope("sub2");
+      assertEquals("sub1/sub2", sub2.prefix());
+    }
+  }
 
   // "handwritten" sample operator classes
   private static final class Const<T> {