diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java | 73 |
1 files changed, 71 insertions, 2 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java index c8b9126f03..49594e6b47 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java @@ -25,18 +25,86 @@ package org.tensorflow; * protocol buffer</a>). */ public class SavedModelBundle implements AutoCloseable { + /** Options for loading a SavedModel. */ + public static final class Loader { + /** Load a <code>SavedModelBundle</code> with the configured options. */ + public SavedModelBundle load() { + return SavedModelBundle.load(exportDir, tags, configProto, runOptions); + } + + /** + * Sets options to use when executing model initialization operations. + * + * @param options Serialized <a + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions + * protocol buffer</a>. + */ + public Loader withRunOptions(byte[] options) { + this.runOptions = options; + return this; + } + + /** + * Set configuration of the <code>Session</code> object created when loading the model. + * + * @param configProto Serialized <a + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto + * protocol buffer</a>. + */ + public Loader withConfigProto(byte[] configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the set of tags that identify the specific graph in the saved model to load. + * + * @param tags the tags identifying the specific MetaGraphDef to load. + */ + public Loader withTags(String... tags) { + this.tags = tags; + return this; + } + + private Loader(String exportDir) { + this.exportDir = exportDir; + } + + private String exportDir = null; + private String[] tags = null; + private byte[] configProto = null; + private byte[] runOptions = null; + } /** * Load a saved model from an export directory. The model that is being loaded should be created * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model * API</a>. * + * <p>This method is a shorthand for: + * + * <pre>{@code + * SavedModelBundler.loader().withTags(tags).load(); + * }</pre> + * * @param exportDir the directory path containing a saved model. * @param tags the tags identifying the specific metagraphdef to load. * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return load(exportDir, tags, null); + return loader(exportDir).withTags(tags).load(); + } + + /** + * Load a saved model. + * + * <p/>Returns a <code>Loader</code> object that can set configuration options before actually + * loading the model, + * + * @param exportDir the directory path containing a saved model. + */ + public static Loader loader(String exportDir) { + return new Loader(exportDir); } /** @@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable { return new SavedModelBundle(graph, session, metaGraphDef); } - private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions); + private static native SavedModelBundle load( + String exportDir, String[] tags, byte[] config, byte[] runOptions); static { TensorFlow.init(); |