aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-10-04 13:33:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 13:41:31 -0700
commit083bd5dde5e6845a6f5e3b83ea2e074d7b28d61f (patch)
treeaeb9d0a9ff17d9a4d014816eb5726ebc17ea1caf
parent2fe6cf285d2bf4222ea09f9e929e538b64bc376b (diff)
Java: Add support for loading op libraries dynamically.
This change adds the equivalent of tf.load_op_library in Python to Java. (https://github.com/tensorflow/tensorflow/commit/5c7f9e316d8c7735308a217310350d416d7498cc was required to make this possible) Though, TensorFlow.loadLibrary() is likely to fail on Windows as symbols required by custom op libraries (those exported by the tensorflow_framework library) are not exported by the monolithic JNI library yet. This should help with #10454 and #13476 PiperOrigin-RevId: 171054707
-rw-r--r--tensorflow/java/BUILD9
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java30
-rw-r--r--tensorflow/java/src/main/native/tensorflow_jni.cc35
-rw-r--r--tensorflow/java/src/main/native/tensorflow_jni.h30
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java23
-rw-r--r--tensorflow/java/src/test/native/my_test_op.cc21
6 files changed, 145 insertions, 3 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 9de79af7d2..a380bc2c71 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -10,8 +10,9 @@ load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar")
load(
"//tensorflow:tensorflow.bzl",
"tf_binary_additional_srcs",
- "tf_copts",
"tf_cc_binary",
+ "tf_copts",
+ "tf_custom_op_library",
"tf_java_test",
)
@@ -180,10 +181,16 @@ tf_java_test(
],
)
+tf_custom_op_library(
+ name = "my_test_op.so",
+ srcs = ["src/test/native/my_test_op.cc"],
+)
+
tf_java_test(
name = "TensorFlowTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"],
+ data = [":my_test_op.so"],
javacopts = JAVACOPTS,
test_class = "org.tensorflow.TensorFlowTest",
deps = [
diff --git a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
index c21214b763..c90655f25d 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java
@@ -29,6 +29,36 @@ public final class TensorFlow {
*/
public static native byte[] registeredOpList();
+ /**
+ * Load the dynamic library in filename and register the operations and kernels present in that
+ * library.
+ *
+ * @param filename Path of the dynamic library containing operations and kernels to load.
+ * @return Serialized bytes of the <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList</a>
+ * protocol buffer message defining the operations defined in the library.
+ * @throws UnsatisfiedLinkError if filename cannot be loaded.
+ */
+ public static byte[] loadLibrary(String filename) {
+ long h = 0;
+ try {
+ h = libraryLoad(filename);
+ } catch (RuntimeException e) {
+ throw new UnsatisfiedLinkError(e.getMessage());
+ }
+ try {
+ return libraryOpList(h);
+ } finally {
+ libraryDelete(h);
+ }
+ }
+
+ private static native long libraryLoad(String filename);
+
+ private static native void libraryDelete(long handle);
+
+ private static native byte[] libraryOpList(long handle);
+
private TensorFlow() {}
/** Load the TensorFlow runtime C library. */
diff --git a/tensorflow/java/src/main/native/tensorflow_jni.cc b/tensorflow/java/src/main/native/tensorflow_jni.cc
index c553582e38..946ab502d1 100644
--- a/tensorflow/java/src/main/native/tensorflow_jni.cc
+++ b/tensorflow/java/src/main/native/tensorflow_jni.cc
@@ -14,7 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/java/src/main/native/tensorflow_jni.h"
+
+#include <limits>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/exception_jni.h"
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv* env,
jclass clazz) {
@@ -30,3 +33,35 @@ Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv* env, jclass clazz) {
TF_DeleteBuffer(buf);
return ret;
}
+
+JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(
+ JNIEnv* env, jclass clazz, jstring filename) {
+ TF_Status* status = TF_NewStatus();
+ const char* cname = env->GetStringUTFChars(filename, nullptr);
+ TF_Library* h = TF_LoadLibrary(cname, status);
+ throwExceptionIfNotOK(env, status);
+ env->ReleaseStringUTFChars(filename, cname);
+ TF_DeleteStatus(status);
+ return reinterpret_cast<jlong>(h);
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(
+ JNIEnv* env, jclass clazz, jlong handle) {
+ if (handle != 0) {
+ TF_DeleteLibraryHandle(reinterpret_cast<TF_Library*>(handle));
+ }
+}
+
+JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_TensorFlow_libraryOpList(
+ JNIEnv* env, jclass clazz, jlong handle) {
+ TF_Buffer buf = TF_GetOpList(reinterpret_cast<TF_Library*>(handle));
+ if (buf.length > std::numeric_limits<jint>::max()) {
+ throwException(env, kIndexOutOfBoundsException,
+ "Serialized OpList is too large for a byte[] array");
+ return nullptr;
+ }
+ auto ret_len = static_cast<jint>(buf.length);
+ jbyteArray ret = env->NewByteArray(ret_len);
+ env->SetByteArrayRegion(ret, 0, ret_len, static_cast<const jbyte*>(buf.data));
+ return ret;
+}
diff --git a/tensorflow/java/src/main/native/tensorflow_jni.h b/tensorflow/java/src/main/native/tensorflow_jni.h
index ecd9b15828..c0c9322020 100644
--- a/tensorflow/java/src/main/native/tensorflow_jni.h
+++ b/tensorflow/java/src/main/native/tensorflow_jni.h
@@ -27,7 +27,7 @@ extern "C" {
* Method: version
* Signature: ()Ljava/lang/String;
*/
-JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv*,
+JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv *,
jclass);
/*
@@ -36,7 +36,33 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv*,
* Signature: ()[B
*/
JNIEXPORT jbyteArray JNICALL
-Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv*, jclass);
+Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv *, jclass);
+
+/*
+ * Class: org_tensorflow_TensorFlow
+ * Method: libraryLoad
+ * Signature: (Ljava/lang/String;)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(JNIEnv *,
+ jclass,
+ jstring);
+
+/*
+ * Class: org_tensorflow_TensorFlow
+ * Method: libraryDelete
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(JNIEnv *,
+ jclass,
+ jlong);
+
+/*
+ * Class: org_tensorflow_TensorFlow
+ * Method: libraryOpList
+ * Signature: (J)[B
+ */
+JNIEXPORT jbyteArray JNICALL
+Java_org_tensorflow_TensorFlow_libraryOpList(JNIEnv *, jclass, jlong);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java
index a31ea900d1..b1fa3f0d7e 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -36,4 +37,26 @@ public class TensorFlowTest {
// was not sorted out. Revisit? Till then, at least exercise the code.
assertTrue(TensorFlow.registeredOpList().length > 0);
}
+
+ @Test
+ public void loadLibrary() {
+ // TODO(ashankar): This tell will fail when built with --config=monolithic.
+ // Figure out how we can ignore the test in that case.
+ try (Graph g = new Graph()) {
+ // Build a graph with an unrecognized operation.
+ try {
+ g.opBuilder("MyTest", "MyTest").build();
+ fail("should not be able to construct graphs with unregistered ops");
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+
+ // Load the library containing the operation.
+ byte[] opList = TensorFlow.loadLibrary("tensorflow/java/my_test_op.so");
+ assertTrue(opList.length > 0);
+
+ // Now graph building should succeed.
+ g.opBuilder("MyTest", "MyTest").build();
+ }
+ }
}
diff --git a/tensorflow/java/src/test/native/my_test_op.cc b/tensorflow/java/src/test/native/my_test_op.cc
new file mode 100644
index 0000000000..eb755901ed
--- /dev/null
+++ b/tensorflow/java/src/test/native/my_test_op.cc
@@ -0,0 +1,21 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+REGISTER_OP("MyTest")
+ .Doc("Custom operation for testing.")
+ .SetShapeFn(tensorflow::shape_inference::UnknownShape);