diff options
Diffstat (limited to 'tensorflow/java/src/main/native/graph_jni.cc')
-rw-r--r-- | tensorflow/java/src/main/native/graph_jni.cc | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index dac6a345e9..a9b2ef6494 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -135,7 +135,7 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, - jlongArray y_handles, jintArray y_indices, + jstring scope_name, jlongArray y_handles, jintArray y_indices, jlongArray x_handles, jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) { @@ -163,9 +163,17 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, } if (env->ExceptionCheck()) return nullptr; + jboolean is_copy; + const char* cscope_name = nullptr; + if (scope_name != nullptr) { + cscope_name = env->GetStringUTFChars(scope_name, &is_copy); + } TF_Status* status = TF_NewStatus(); - TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); - + TF_AddGradients(g, cscope_name, y.get(), ny, x.get(), nx, dx.get(), status, + dy.get()); + if (scope_name != nullptr) { + env->ReleaseStringUTFChars(scope_name, cscope_name); + } if (!throwExceptionIfNotOK(env, status)) { TF_DeleteStatus(status); return nullptr; |