From 8aa4179ffae2d0a3724a70bea32ce35e2d88751a Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Sat, 14 Jul 2018 00:15:49 -0700 Subject: [Java]: Support ConfigProto and RunOptions when loading SavedModels. Fixes #18143 Fixes #20769 (Similar to #18716 by @raintung) PiperOrigin-RevId: 204575441 --- .../main/java/org/tensorflow/SavedModelBundle.java | 73 +++++++++++++++++++++- .../java/src/main/native/saved_model_bundle_jni.cc | 15 ++++- .../java/src/main/native/saved_model_bundle_jni.h | 4 +- .../java/org/tensorflow/SavedModelBundleTest.java | 54 ++++++++++++++++ 4 files changed, 141 insertions(+), 5 deletions(-) (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java index c8b9126f03..49594e6b47 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java @@ -25,18 +25,86 @@ package org.tensorflow; * protocol buffer). */ public class SavedModelBundle implements AutoCloseable { + /** Options for loading a SavedModel. */ + public static final class Loader { + /** Load a SavedModelBundle with the configured options. */ + public SavedModelBundle load() { + return SavedModelBundle.load(exportDir, tags, configProto, runOptions); + } + + /** + * Sets options to use when executing model initialization operations. + * + * @param options Serialized RunOptions + * protocol buffer. + */ + public Loader withRunOptions(byte[] options) { + this.runOptions = options; + return this; + } + + /** + * Set configuration of the Session object created when loading the model. + * + * @param configProto Serialized ConfigProto + * protocol buffer. + */ + public Loader withConfigProto(byte[] configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the set of tags that identify the specific graph in the saved model to load. + * + * @param tags the tags identifying the specific MetaGraphDef to load. + */ + public Loader withTags(String... tags) { + this.tags = tags; + return this; + } + + private Loader(String exportDir) { + this.exportDir = exportDir; + } + + private String exportDir = null; + private String[] tags = null; + private byte[] configProto = null; + private byte[] runOptions = null; + } /** * Load a saved model from an export directory. The model that is being loaded should be created * using the Saved Model * API. * + *

This method is a shorthand for: + * + *

{@code
+   * SavedModelBundler.loader().withTags(tags).load();
+   * }
+ * * @param exportDir the directory path containing a saved model. * @param tags the tags identifying the specific metagraphdef to load. * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return load(exportDir, tags, null); + return loader(exportDir).withTags(tags).load(); + } + + /** + * Load a saved model. + * + *

Returns a Loader object that can set configuration options before actually + * loading the model, + * + * @param exportDir the directory path containing a saved model. + */ + public static Loader loader(String exportDir) { + return new Loader(exportDir); } /** @@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable { return new SavedModelBundle(graph, session, metaGraphDef); } - private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions); + private static native SavedModelBundle load( + String exportDir, String[] tags, byte[] config, byte[] runOptions); static { TensorFlow.init(); diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc index de6382a79c..68999fb2da 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc @@ -22,12 +22,25 @@ limitations under the License. JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags, - jbyteArray run_options) { + jbyteArray config, jbyteArray run_options) { TF_Status* status = TF_NewStatus(); jobject bundle = nullptr; // allocate parameters for TF_LoadSessionFromSavedModel TF_SessionOptions* opts = TF_NewSessionOptions(); + if (config != nullptr) { + size_t sz = env->GetArrayLength(config); + if (sz > 0) { + jbyte* config_data = env->GetByteArrayElements(config, nullptr); + TF_SetConfig(opts, static_cast(config_data), sz, status); + env->ReleaseByteArrayElements(config, config_data, JNI_ABORT); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteSessionOptions(opts); + TF_DeleteStatus(status); + return nullptr; + } + } + } TF_Buffer* crun_options = nullptr; if (run_options != nullptr) { size_t sz = env->GetArrayLength(run_options); diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h index 6cce6a81bd..a4b05d0409 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h @@ -26,10 +26,10 @@ extern "C" { * Class: org_tensorflow_SavedModelBundle * Method: load * Signature: - * (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle; + * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle; */ JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( - JNIEnv *, jclass, jstring, jobjectArray, jbyteArray); + JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java index b063b6f1cd..7d936867a7 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -50,4 +50,58 @@ public class SavedModelBundleTest { assertTrue(e.getMessage().contains("Could not find SavedModel")); } } + + @Test + public void loader() { + try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH) + .withTags("serve") + .withConfigProto(sillyConfigProto()) + .withRunOptions(sillyRunOptions()) + .load()) { + assertNotNull(bundle.session()); + assertNotNull(bundle.graph()); + assertNotNull(bundle.metaGraphDef()); + } + } + + private static byte[] sillyRunOptions() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x08, 0x03}; + /* + return org.tensorflow.framework.RunOptions.newBuilder() + .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) + .build() + .toByteArray(); + */ + } + + public static byte[] sillyConfigProto() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x10, 0x01, 0x28, 0x01}; + /* + return org.tensorflow.framework.ConfigProto.newBuilder() + .setInterOpParallelismThreads(1) + .setIntraOpParallelismThreads(1) + .build() + .toByteArray(); + */ + } } -- cgit v1.2.3