diff options
Diffstat (limited to 'tensorflow/java/src/main/native/graph_jni.cc')
-rw-r--r-- | tensorflow/java/src/main/native/graph_jni.cc | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index dac6a345e9..f1744d8769 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { return ret; } -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, - jlongArray y_handles, jintArray y_indices, - jlongArray x_handles, jintArray x_indices, - jlongArray dx_handles, jintArray dx_indices) { - +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( + JNIEnv* env, jclass clazz, jlong handle, jstring prefix, + jlongArray y_handles, jintArray y_indices, jlongArray x_handles, + jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) { TF_Graph* g = requireHandle(env, handle); if (g == nullptr) return nullptr; @@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, } if (env->ExceptionCheck()) return nullptr; + const char* cprefix = nullptr; + if (prefix != nullptr) { + cprefix = env->GetStringUTFChars(prefix, nullptr); + } TF_Status* status = TF_NewStatus(); - TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); - + TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(), + status, dy.get()); + if (prefix != nullptr) { + env->ReleaseStringUTFChars(prefix, cprefix); + } if (!throwExceptionIfNotOK(env, status)) { TF_DeleteStatus(status); return nullptr; |