diff options
author | Asim Shankar <ashankar@google.com> | 2017-10-04 13:33:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-04 13:41:31 -0700 |
commit | 083bd5dde5e6845a6f5e3b83ea2e074d7b28d61f (patch) | |
tree | aeb9d0a9ff17d9a4d014816eb5726ebc17ea1caf | |
parent | 2fe6cf285d2bf4222ea09f9e929e538b64bc376b (diff) |
Java: Add support for loading op libraries dynamically.
This change adds the equivalent of tf.load_op_library in Python to Java.
(https://github.com/tensorflow/tensorflow/commit/5c7f9e316d8c7735308a217310350d416d7498cc
was required to make this possible)
Though, TensorFlow.loadLibrary() is likely to fail on Windows as symbols
required by custom op libraries (those exported by the tensorflow_framework library)
are not exported by the monolithic JNI library yet.
This should help with #10454 and #13476
PiperOrigin-RevId: 171054707
-rw-r--r-- | tensorflow/java/BUILD | 9 | ||||
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java | 30 | ||||
-rw-r--r-- | tensorflow/java/src/main/native/tensorflow_jni.cc | 35 | ||||
-rw-r--r-- | tensorflow/java/src/main/native/tensorflow_jni.h | 30 | ||||
-rw-r--r-- | tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java | 23 | ||||
-rw-r--r-- | tensorflow/java/src/test/native/my_test_op.cc | 21 |
6 files changed, 145 insertions, 3 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 9de79af7d2..a380bc2c71 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -10,8 +10,9 @@ load(":src/gen/gen_ops.bzl", "tf_java_op_gen_srcjar") load( "//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs", - "tf_copts", "tf_cc_binary", + "tf_copts", + "tf_custom_op_library", "tf_java_test", ) @@ -180,10 +181,16 @@ tf_java_test( ], ) +tf_custom_op_library( + name = "my_test_op.so", + srcs = ["src/test/native/my_test_op.cc"], +) + tf_java_test( name = "TensorFlowTest", size = "small", srcs = ["src/test/java/org/tensorflow/TensorFlowTest.java"], + data = [":my_test_op.so"], javacopts = JAVACOPTS, test_class = "org.tensorflow.TensorFlowTest", deps = [ diff --git a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java index c21214b763..c90655f25d 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java +++ b/tensorflow/java/src/main/java/org/tensorflow/TensorFlow.java @@ -29,6 +29,36 @@ public final class TensorFlow { */ public static native byte[] registeredOpList(); + /** + * Load the dynamic library in filename and register the operations and kernels present in that + * library. + * + * @param filename Path of the dynamic library containing operations and kernels to load. + * @return Serialized bytes of the <a + * href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList</a> + * protocol buffer message defining the operations defined in the library. + * @throws UnsatisfiedLinkError if filename cannot be loaded. + */ + public static byte[] loadLibrary(String filename) { + long h = 0; + try { + h = libraryLoad(filename); + } catch (RuntimeException e) { + throw new UnsatisfiedLinkError(e.getMessage()); + } + try { + return libraryOpList(h); + } finally { + libraryDelete(h); + } + } + + private static native long libraryLoad(String filename); + + private static native void libraryDelete(long handle); + + private static native byte[] libraryOpList(long handle); + private TensorFlow() {} /** Load the TensorFlow runtime C library. */ diff --git a/tensorflow/java/src/main/native/tensorflow_jni.cc b/tensorflow/java/src/main/native/tensorflow_jni.cc index c553582e38..946ab502d1 100644 --- a/tensorflow/java/src/main/native/tensorflow_jni.cc +++ b/tensorflow/java/src/main/native/tensorflow_jni.cc @@ -14,7 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/java/src/main/native/tensorflow_jni.h" + +#include <limits> #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/exception_jni.h" JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv* env, jclass clazz) { @@ -30,3 +33,35 @@ Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv* env, jclass clazz) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad( + JNIEnv* env, jclass clazz, jstring filename) { + TF_Status* status = TF_NewStatus(); + const char* cname = env->GetStringUTFChars(filename, nullptr); + TF_Library* h = TF_LoadLibrary(cname, status); + throwExceptionIfNotOK(env, status); + env->ReleaseStringUTFChars(filename, cname); + TF_DeleteStatus(status); + return reinterpret_cast<jlong>(h); +} + +JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete( + JNIEnv* env, jclass clazz, jlong handle) { + if (handle != 0) { + TF_DeleteLibraryHandle(reinterpret_cast<TF_Library*>(handle)); + } +} + +JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_TensorFlow_libraryOpList( + JNIEnv* env, jclass clazz, jlong handle) { + TF_Buffer buf = TF_GetOpList(reinterpret_cast<TF_Library*>(handle)); + if (buf.length > std::numeric_limits<jint>::max()) { + throwException(env, kIndexOutOfBoundsException, + "Serialized OpList is too large for a byte[] array"); + return nullptr; + } + auto ret_len = static_cast<jint>(buf.length); + jbyteArray ret = env->NewByteArray(ret_len); + env->SetByteArrayRegion(ret, 0, ret_len, static_cast<const jbyte*>(buf.data)); + return ret; +} diff --git a/tensorflow/java/src/main/native/tensorflow_jni.h b/tensorflow/java/src/main/native/tensorflow_jni.h index ecd9b15828..c0c9322020 100644 --- a/tensorflow/java/src/main/native/tensorflow_jni.h +++ b/tensorflow/java/src/main/native/tensorflow_jni.h @@ -27,7 +27,7 @@ extern "C" { * Method: version * Signature: ()Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv*, +JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv *, jclass); /* @@ -36,7 +36,33 @@ JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv*, * Signature: ()[B */ JNIEXPORT jbyteArray JNICALL -Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv*, jclass); +Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv *, jclass); + +/* + * Class: org_tensorflow_TensorFlow + * Method: libraryLoad + * Signature: (Ljava/lang/String;)J + */ +JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(JNIEnv *, + jclass, + jstring); + +/* + * Class: org_tensorflow_TensorFlow + * Method: libraryDelete + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(JNIEnv *, + jclass, + jlong); + +/* + * Class: org_tensorflow_TensorFlow + * Method: libraryOpList + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_tensorflow_TensorFlow_libraryOpList(JNIEnv *, jclass, jlong); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java index a31ea900d1..b1fa3f0d7e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorFlowTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import org.junit.Test; import org.junit.runner.RunWith; @@ -36,4 +37,26 @@ public class TensorFlowTest { // was not sorted out. Revisit? Till then, at least exercise the code. assertTrue(TensorFlow.registeredOpList().length > 0); } + + @Test + public void loadLibrary() { + // TODO(ashankar): This tell will fail when built with --config=monolithic. + // Figure out how we can ignore the test in that case. + try (Graph g = new Graph()) { + // Build a graph with an unrecognized operation. + try { + g.opBuilder("MyTest", "MyTest").build(); + fail("should not be able to construct graphs with unregistered ops"); + } catch (IllegalArgumentException e) { + // expected exception + } + + // Load the library containing the operation. + byte[] opList = TensorFlow.loadLibrary("tensorflow/java/my_test_op.so"); + assertTrue(opList.length > 0); + + // Now graph building should succeed. + g.opBuilder("MyTest", "MyTest").build(); + } + } } diff --git a/tensorflow/java/src/test/native/my_test_op.cc b/tensorflow/java/src/test/native/my_test_op.cc new file mode 100644 index 0000000000..eb755901ed --- /dev/null +++ b/tensorflow/java/src/test/native/my_test_op.cc @@ -0,0 +1,21 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +REGISTER_OP("MyTest") + .Doc("Custom operation for testing.") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); |