aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session.go
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-02-08 11:55:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-08 12:08:23 -0800
commit085102c2e2947d76056b6363da96c55ecd838e6c (patch)
tree4fb72284384f7e985503b959a33e9322f6b7a13c /tensorflow/go/session.go
parent42dc6764a0518ec0937bb9d94949c2af29508371 (diff)
SavedModel support in Go.
Change: 146938337
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r--tensorflow/go/session.go38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index dd629441ef..c29b6e0b76 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -59,6 +59,44 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
+// LoadSavedModel creates a new Session from a model previously exported to a
+// directory on disk.
+//
+// Exported models contain a set of graphs and variable values. Tags in the
+// model identify a single graph. LoadSessionFromSavedModel initializes a
+// session with the identified graph and with variables initialized to saved
+// values.
+//
+// The tensorflow package currently does not have the ability to export a model
+// to a directory from Go. This function thus currently targets loading models
+// exported in other languages, such as using tf.saved_model.builder in Python.
+// See:
+// https://www.tensorflow.org/code/tensorflow/python/saved_model/
+func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*Session, *Graph, error) {
+ status := newStatus()
+ cOpt := options.c()
+ cExportDir := C.CString(exportDir)
+ cTags := make([]*C.char, len(tags))
+ for i := range tags {
+ cTags[i] = C.CString(tags[i])
+ }
+ graph := NewGraph()
+ // TODO(jhseu): Add support for run_options and meta_graph_def.
+ cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
+ for i := range cTags {
+ C.free(unsafe.Pointer(cTags[i]))
+ }
+ C.free(unsafe.Pointer(cExportDir))
+ C.TF_DeleteSessionOptions(cOpt)
+
+ if err := status.Err(); err != nil {
+ return nil, nil, err
+ }
+ s := &Session{c: cSess}
+ runtime.SetFinalizer(s, func(s *Session) { s.Close() })
+ return s, graph, nil
+}
+
// Run the graph with the associated session starting with the supplied inputs.
// inputs and outputs may be set to nil. Runs, but does not return Tensors
// for operations specified in targets.