aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-14 00:15:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-14 00:18:56 -0700
commit8aa4179ffae2d0a3724a70bea32ce35e2d88751a (patch)
treead946f3b966e91c1e996d22e33be2edc27bff72f /tensorflow/java
parent88b656acd480f6956894e3bb8c8f0c52fe033bc4 (diff)
[Java]: Support ConfigProto and RunOptions when loading SavedModels.
Fixes #18143 Fixes #20769 (Similar to #18716 by @raintung) PiperOrigin-RevId: 204575441
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java73
-rw-r--r--tensorflow/java/src/main/native/saved_model_bundle_jni.cc15
-rw-r--r--tensorflow/java/src/main/native/saved_model_bundle_jni.h4
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java54
4 files changed, 141 insertions, 5 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
index c8b9126f03..49594e6b47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -25,18 +25,86 @@ package org.tensorflow;
* protocol buffer</a>).
*/
public class SavedModelBundle implements AutoCloseable {
+ /** Options for loading a SavedModel. */
+ public static final class Loader {
+ /** Load a <code>SavedModelBundle</code> with the configured options. */
+ public SavedModelBundle load() {
+ return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
+ }
+
+ /**
+ * Sets options to use when executing model initialization operations.
+ *
+ * @param options Serialized <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
+ * protocol buffer</a>.
+ */
+ public Loader withRunOptions(byte[] options) {
+ this.runOptions = options;
+ return this;
+ }
+
+ /**
+ * Set configuration of the <code>Session</code> object created when loading the model.
+ *
+ * @param configProto Serialized <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
+ * protocol buffer</a>.
+ */
+ public Loader withConfigProto(byte[] configProto) {
+ this.configProto = configProto;
+ return this;
+ }
+
+ /**
+ * Sets the set of tags that identify the specific graph in the saved model to load.
+ *
+ * @param tags the tags identifying the specific MetaGraphDef to load.
+ */
+ public Loader withTags(String... tags) {
+ this.tags = tags;
+ return this;
+ }
+
+ private Loader(String exportDir) {
+ this.exportDir = exportDir;
+ }
+
+ private String exportDir = null;
+ private String[] tags = null;
+ private byte[] configProto = null;
+ private byte[] runOptions = null;
+ }
/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
* API</a>.
*
+ * <p>This method is a shorthand for:
+ *
+ * <pre>{@code
+ * SavedModelBundler.loader().withTags(tags).load();
+ * }</pre>
+ *
* @param exportDir the directory path containing a saved model.
* @param tags the tags identifying the specific metagraphdef to load.
* @return a bundle containing the graph and associated session.
*/
public static SavedModelBundle load(String exportDir, String... tags) {
- return load(exportDir, tags, null);
+ return loader(exportDir).withTags(tags).load();
+ }
+
+ /**
+ * Load a saved model.
+ *
+ * <p/>Returns a <code>Loader</code> object that can set configuration options before actually
+ * loading the model,
+ *
+ * @param exportDir the directory path containing a saved model.
+ */
+ public static Loader loader(String exportDir) {
+ return new Loader(exportDir);
}
/**
@@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
return new SavedModelBundle(graph, session, metaGraphDef);
}
- private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
+ private static native SavedModelBundle load(
+ String exportDir, String[] tags, byte[] config, byte[] runOptions);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
index de6382a79c..68999fb2da 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
@@ -22,12 +22,25 @@ limitations under the License.
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
- jbyteArray run_options) {
+ jbyteArray config, jbyteArray run_options) {
TF_Status* status = TF_NewStatus();
jobject bundle = nullptr;
// allocate parameters for TF_LoadSessionFromSavedModel
TF_SessionOptions* opts = TF_NewSessionOptions();
+ if (config != nullptr) {
+ size_t sz = env->GetArrayLength(config);
+ if (sz > 0) {
+ jbyte* config_data = env->GetByteArrayElements(config, nullptr);
+ TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
+ env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteSessionOptions(opts);
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ }
+ }
TF_Buffer* crun_options = nullptr;
if (run_options != nullptr) {
size_t sz = env->GetArrayLength(run_options);
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
index 6cce6a81bd..a4b05d0409 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
@@ -26,10 +26,10 @@ extern "C" {
* Class: org_tensorflow_SavedModelBundle
* Method: load
* Signature:
- * (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle;
+ * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
*/
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
- JNIEnv *, jclass, jstring, jobjectArray, jbyteArray);
+ JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
index b063b6f1cd..7d936867a7 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -50,4 +50,58 @@ public class SavedModelBundleTest {
assertTrue(e.getMessage().contains("Could not find SavedModel"));
}
}
+
+ @Test
+ public void loader() {
+ try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH)
+ .withTags("serve")
+ .withConfigProto(sillyConfigProto())
+ .withRunOptions(sillyRunOptions())
+ .load()) {
+ assertNotNull(bundle.session());
+ assertNotNull(bundle.graph());
+ assertNotNull(bundle.metaGraphDef());
+ }
+ }
+
+ private static byte[] sillyRunOptions() {
+ // Ideally this would use the generated Java sources for protocol buffers
+ // and end up with something like the snippet below. However, generating
+ // the Java files for the .proto files in tensorflow/core:protos_all is
+ // a bit cumbersome in bazel until the proto_library rule is setup.
+ //
+ // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
+ //
+ // For this test, for now, the use of specific bytes suffices.
+ return new byte[] {0x08, 0x03};
+ /*
+ return org.tensorflow.framework.RunOptions.newBuilder()
+ .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
+ .build()
+ .toByteArray();
+ */
+ }
+
+ public static byte[] sillyConfigProto() {
+ // Ideally this would use the generated Java sources for protocol buffers
+ // and end up with something like the snippet below. However, generating
+ // the Java files for the .proto files in tensorflow/core:protos_all is
+ // a bit cumbersome in bazel until the proto_library rule is setup.
+ //
+ // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
+ //
+ // For this test, for now, the use of specific bytes suffices.
+ return new byte[] {0x10, 0x01, 0x28, 0x01};
+ /*
+ return org.tensorflow.framework.ConfigProto.newBuilder()
+ .setInterOpParallelismThreads(1)
+ .setIntraOpParallelismThreads(1)
+ .build()
+ .toByteArray();
+ */
+ }
}