aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-25 22:12:24 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-07-25 21:10:29 -0400
commit2b303fddafec6b96a6868aaa76f55cc392b96586 (patch)
tree8b1da320c69ba5239f8bdd37bfac95cd02704d65 /tensorflow/java
parentb24037513f12a5812a21b7ea92ff904ee9ea6cd8 (diff)
Add scope name to TF_AddGradients
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java17
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java4
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java4
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc14
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h6
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java19
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java17
8 files changed, 76 insertions, 18 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..f2bd3e99a5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -153,12 +153,13 @@ public final class Graph implements AutoCloseable {
* If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
* shapes in {@code y}.
*
+ * @param scopeName name of the subscope into which gradient operations are added. If null, defaults to "gradients".
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String scopeName, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +186,12 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native
+ // handles of the gradient operations while the second holds the index of their output e.g. given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -209,14 +210,14 @@ public final class Graph implements AutoCloseable {
* i.e., {@code dy/dx_1, dy/dx_2...}
* <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
- * a single output and {@code dx} is null.
+ * a single output, {@code dx} is null and {@code scopeName} is null.
*
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[]{y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,7 +331,7 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices,
long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
static {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
index 2e84cac1ac..92e05d2d6d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/NameScope.java
@@ -56,6 +56,10 @@ final class NameScope {
String actualName = (opName != null) ? opName : name;
return fullyQualify(makeUnique(actualName));
}
+
+ String prefix() {
+ return opPrefix;
+ }
/**
* Create a new, root-level namescope.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..d1ab44c3b2 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -154,6 +154,19 @@ public final class Scope {
public String makeOpName(String defaultName) {
return nameScope.makeOpName(defaultName);
}
+
+ /**
+ * The name prefix of this scope
+ * <p>
+ * This value is the combination of the name of this scope and all of its parents, separated by a '/', e.g.
+ * <pre>{@code
+ * Scope scope = new Scope(graph);
+ * assertEquals("sub1/sub2", scope.withSubScope("sub1").withSubScope("sub2").prefix());
+ * }</pre>
+ */
+ public String prefix() {
+ return nameScope.prefix();
+ }
private Scope(Graph graph, NameScope nameScope) {
this.graph = graph;
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..d88dc3ba46 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -88,8 +88,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new Gradients(Arrays.asList(gradOutputs));
+ Output<?>[] dy = scope.graph().addGradients(scope.prefix(), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(dy));
}
/**
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..a9b2ef6494 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -135,7 +135,7 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
JNIEXPORT jlongArray JNICALL
Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jlongArray y_handles, jintArray y_indices,
+ jstring scope_name, jlongArray y_handles, jintArray y_indices,
jlongArray x_handles, jintArray x_indices,
jlongArray dx_handles, jintArray dx_indices) {
@@ -163,9 +163,17 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
}
if (env->ExceptionCheck()) return nullptr;
+ jboolean is_copy;
+ const char* cscope_name = nullptr;
+ if (scope_name != nullptr) {
+ cscope_name = env->GetStringUTFChars(scope_name, &is_copy);
+ }
TF_Status* status = TF_NewStatus();
- TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+ TF_AddGradients(g, cscope_name, y.get(), ny, x.get(), nx, dx.get(), status,
+ dy.get());
+ if (scope_name != nullptr) {
+ env->ReleaseStringUTFChars(scope_name, cscope_name);
+ }
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..e483bf953b 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
* Class: org_tensorflow_Graph
* Method: name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
- jintArray);
+ jclass, jlong, jstring, jlongArray, jintArray, jlongArray, jintArray,
+ jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..c02336aebe 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -181,7 +181,7 @@ public class GraphTest {
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
- Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +212,7 @@ public class GraphTest {
assertEquals(1, grad0.length);
assertEquals(DataType.FLOAT, grad0[0].dataType());
- Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
assertNotNull(grad1);
assertEquals(1, grad1.length);
assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -229,6 +229,21 @@ public class GraphTest {
}
}
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+ Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+ Output<?>[] grad1 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad1[0].op().name().startsWith("more_gradients/"));
+ }
+ }
+
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
index 125de73554..2057007499 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/ScopeTest.java
@@ -17,10 +17,12 @@ package org.tensorflow.op;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.HashMap;
import java.util.Map;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -181,6 +183,21 @@ public class ScopeTest {
assertEquals(21704, result.intValue());
}
}
+
+ @Test
+ public void prefix() {
+ try (Graph g = new Graph()) {
+ Scope s = new Scope(g);
+ assertNotNull(s.prefix());
+ assertTrue(s.prefix().isEmpty());
+
+ Scope sub1 = s.withSubScope("sub1");
+ assertEquals("sub1", sub1.prefix());
+
+ Scope sub2 = sub1.withSubScope("sub2");
+ assertEquals("sub1/sub2", sub2.prefix());
+ }
+ }
// "handwritten" sample operator classes
private static final class Const<T> {