aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2017-02-14 17:30:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-14 17:47:15 -0800
commitdd51f989b8ca738da8a04970857597ed68fa1a15 (patch)
tree6f7738c4d94c45b958f9d279ddaf28c71077ba57
parentb5d6636e5d60dc3094285d097264c3d1cf250a53 (diff)
Go: Add a SavedModel type
Change: 147543652
-rw-r--r--tensorflow/go/saved_model.go68
-rw-r--r--tensorflow/go/saved_model_test.go28
-rw-r--r--tensorflow/go/session.go38
-rw-r--r--tensorflow/go/session_test.go12
4 files changed, 96 insertions, 50 deletions
diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go
new file mode 100644
index 0000000000..9bffa61765
--- /dev/null
+++ b/tensorflow/go/saved_model.go
@@ -0,0 +1,68 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package tensorflow
+
+import (
+ "runtime"
+ "unsafe"
+)
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+// SavedModel represents the contents of loaded SavedModel.
+// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.
+type SavedModel struct {
+ Session *Session
+ Graph *Graph
+}
+
+// LoadSavedModel creates a new SavedModel from a model previously
+// exported to a directory on disk.
+//
+// Exported models contain a set of graphs and, optionally, variable values.
+// Tags in the model identify a single graph. LoadSavedModel initializes a
+// session with the identified graph and with variables initialized to from the
+// checkpoints on disk.
+//
+// The tensorflow package currently does not have the ability to export a model
+// to a directory from Go. This function thus currently targets loading models
+// exported in other languages, such as using tf.saved_model.builder in Python.
+// See:
+// https://www.tensorflow.org/code/tensorflow/python/saved_model/
+func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) {
+ status := newStatus()
+ cOpt := options.c()
+ cExportDir := C.CString(exportDir)
+ cTags := make([]*C.char, len(tags))
+ for i := range tags {
+ cTags[i] = C.CString(tags[i])
+ }
+ graph := NewGraph()
+ // TODO(jhseu): Add support for run_options and meta_graph_def.
+ cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
+ for i := range cTags {
+ C.free(unsafe.Pointer(cTags[i]))
+ }
+ C.free(unsafe.Pointer(cExportDir))
+ C.TF_DeleteSessionOptions(cOpt)
+
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ s := &Session{c: cSess}
+ runtime.SetFinalizer(s, func(s *Session) { s.Close() })
+ return &SavedModel{Session: s, Graph: graph}, nil
+}
diff --git a/tensorflow/go/saved_model_test.go b/tensorflow/go/saved_model_test.go
new file mode 100644
index 0000000000..685312ae8c
--- /dev/null
+++ b/tensorflow/go/saved_model_test.go
@@ -0,0 +1,28 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package tensorflow
+
+import "testing"
+
+func TestSavedModel(t *testing.T) {
+ bundle, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
+ if err != nil {
+ t.Fatalf("LoadSavedModel(): %v", err)
+ }
+ if op := bundle.Graph.Operation("y"); op == nil {
+ t.Fatalf("\"y\" not found in graph")
+ }
+ // TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
+ // more thorough test when the generated protobufs are available.
+}
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index c29b6e0b76..dd629441ef 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -59,44 +59,6 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
-// LoadSavedModel creates a new Session from a model previously exported to a
-// directory on disk.
-//
-// Exported models contain a set of graphs and variable values. Tags in the
-// model identify a single graph. LoadSessionFromSavedModel initializes a
-// session with the identified graph and with variables initialized to saved
-// values.
-//
-// The tensorflow package currently does not have the ability to export a model
-// to a directory from Go. This function thus currently targets loading models
-// exported in other languages, such as using tf.saved_model.builder in Python.
-// See:
-// https://www.tensorflow.org/code/tensorflow/python/saved_model/
-func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*Session, *Graph, error) {
- status := newStatus()
- cOpt := options.c()
- cExportDir := C.CString(exportDir)
- cTags := make([]*C.char, len(tags))
- for i := range tags {
- cTags[i] = C.CString(tags[i])
- }
- graph := NewGraph()
- // TODO(jhseu): Add support for run_options and meta_graph_def.
- cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c)
- for i := range cTags {
- C.free(unsafe.Pointer(cTags[i]))
- }
- C.free(unsafe.Pointer(cExportDir))
- C.TF_DeleteSessionOptions(cOpt)
-
- if err := status.Err(); err != nil {
- return nil, nil, err
- }
- s := &Session{c: cSess}
- runtime.SetFinalizer(s, func(s *Session) { s.Close() })
- return s, graph, nil
-}
-
// Run the graph with the associated session starting with the supplied inputs.
// inputs and outputs may be set to nil. Runs, but does not return Tensors
// for operations specified in targets.
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index ccd7d85295..14ecca402b 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -181,15 +181,3 @@ func TestConcurrency(t *testing.T) {
t.Errorf("Close() 2: %v", err)
}
}
-
-func TestSavedModel(t *testing.T) {
- _, graph, err := LoadSavedModel("../cc/saved_model/testdata/half_plus_two/00000123", []string{"serve"}, nil)
- if err != nil {
- t.Fatalf("LoadSavedModel(): %v", err)
- }
- if op := graph.Operation("y"); op == nil {
- t.Fatalf("\"y\" not found in graph")
- }
- // TODO(jhseu): half_plus_two has a tf.Example proto dependency to run. Add a
- // more thorough test when the generated protobufs are available.
-}