aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java101
1 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
new file mode 100644
index 0000000000..c3938fe23f
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -0,0 +1,101 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * SavedModelBundle represents a model loaded from storage.
+ *
+ * <p>The model consists of a description of the computation (a {@link Graph}), a {@link Session}
+ * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage,
+ * and a description of the model (a serialized representation of a <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
+ * protocol buffer</a>).
+ */
+public class SavedModelBundle implements AutoCloseable {
+
+ /**
+ * Load a saved model from an export directory.
+ *
+ * @param exportDir the directory path containing a saved model.
+ * @param tags the tags identifying the specific metagraphdef to load.
+ * @return a bundle containing the graph and associated session.
+ */
+ public static SavedModelBundle load(String exportDir, String... tags) {
+ return load(exportDir, tags, null);
+ }
+
+ /**
+ * Returns the serialized <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto">MetaGraphDef
+ * protocol buffer</a> associated with the saved model.
+ */
+ public byte[] metaGraphDef() {
+ return metaGraphDef;
+ }
+
+ /** Returns the graph that describes the computation performed by the model. */
+ public Graph graph() {
+ return graph;
+ }
+
+ /**
+ * Returns the {@link Session} with which to perform computation using the model.
+ *
+ * @return the initialized session
+ */
+ public Session session() {
+ return session;
+ }
+
+ /**
+ * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model
+ * bundle.
+ */
+ @Override
+ public void close() {
+ session.close();
+ graph.close();
+ }
+
+ private final Graph graph;
+ private final Session session;
+ private final byte[] metaGraphDef;
+
+ private SavedModelBundle(Graph graph, Session session, byte[] metaGraphDef) {
+ this.graph = graph;
+ this.session = session;
+ this.metaGraphDef = metaGraphDef;
+ }
+
+ /**
+ * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session
+ * object, plus the serialized MetaGraphDef.
+ *
+ * <p>Invoked from the native load method. Takes ownership of the handles.
+ */
+ private static SavedModelBundle fromHandle(
+ long graphHandle, long sessionHandle, byte[] metaGraphDef) {
+ Graph graph = new Graph(graphHandle);
+ Session session = new Session(graph, sessionHandle);
+ return new SavedModelBundle(graph, session, metaGraphDef);
+ }
+
+ private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
+
+ static {
+ TensorFlow.init();
+ }
+}