From 7ebdc9834bbc583bcc42551b660c8ed256ea7416 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Sun, 8 Jul 2018 00:21:45 -0400 Subject: 1st code review: rename 'scope_name' to 'prefix', etc. --- tensorflow/java/src/main/java/org/tensorflow/Graph.java | 11 ++++++----- .../java/src/main/java/org/tensorflow/op/Scope.java | 2 +- tensorflow/java/src/main/native/graph_jni.cc | 17 ++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index f2bd3e99a5..353092701b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -153,13 +153,14 @@ public final class Graph implements AutoCloseable { * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all * shapes in {@code y}. * - * @param scopeName name of the subscope into which gradients operations are added. If null, defaults to "gradients". + * @param prefix string prefix applied to names of nodes added to the graph to compute gradients. + * If null, defaults to "gradients". * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return the partial derivatives {@code dy} with the size of {@code x} */ - public Output[] addGradients(String scopeName, Output[] y, Output[] x, Output[] dx) { + public Output[] addGradients(String prefix, Output[] y, Output[] x, Output[] dx) { Output[] dy = new Output[x.length]; final long[] yHandles = new long[y.length]; final int[] yIndices = new int[y.length]; @@ -191,7 +192,7 @@ public final class Graph implements AutoCloseable { // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] long[] dyHandlesAndIndices = - addGradients(ref.nativeHandle(), scopeName, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + addGradients(ref.nativeHandle(), prefix, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); int ndy = dyHandlesAndIndices.length >> 1; if (ndy != dy.length) { throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length @@ -210,7 +211,7 @@ public final class Graph implements AutoCloseable { * i.e., {@code dy/dx_1, dy/dx_2...} *

* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is - * a single output, {@code dx} is null and {@code scopeName} is null. + * a single output, {@code dx} is null and {@code prefix} is null. * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed @@ -331,7 +332,7 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); - private static native long[] addGradients(long handle, String scopeName, long[] inputHandles, int[] inputIndices, + private static native long[] addGradients(long handle, String prefix, long[] inputHandles, int[] inputIndices, long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); static { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 51a6ce8318..cf0b3d98c1 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -156,7 +156,7 @@ public final class Scope { } /** - * The name prefix of this scope + * The name prefix of this scope. *

* This value is the combination of the name of this scope and all of its parents, seperated by a '/', e.g. *

{@code
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index a9b2ef6494..1bbda52641 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -135,7 +135,7 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
 
 JNIEXPORT jlongArray JNICALL
 Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
-    jstring scope_name, jlongArray y_handles, jintArray y_indices,
+    jstring prefix, jlongArray y_handles, jintArray y_indices,
     jlongArray x_handles, jintArray x_indices,
     jlongArray dx_handles, jintArray dx_indices) {
 
@@ -163,16 +163,15 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
   }
   if (env->ExceptionCheck()) return nullptr;
 
-  jboolean is_copy;
-  const char* cscope_name = nullptr;
-  if (scope_name != nullptr) {
-    cscope_name = env->GetStringUTFChars(scope_name, &is_copy);
+  const char* cprefix = nullptr;
+  if (prefix != nullptr) {
+    cprefix = env->GetStringUTFChars(prefix, nullptr);
   }
   TF_Status* status = TF_NewStatus();
-  TF_AddGradients(g, cscope_name, y.get(), ny, x.get(), nx, dx.get(), status,
-      dy.get());
-  if (scope_name != nullptr) {
-    env->ReleaseStringUTFChars(scope_name, cscope_name);
+  TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+                            status, dy.get());
+  if (prefix != nullptr) {
+    env->ReleaseStringUTFChars(prefix, cprefix);
   }
   if (!throwExceptionIfNotOK(env, status)) {
     TF_DeleteStatus(status);
-- 
cgit v1.2.3