diff options
author | Jonathan Hseu <jhseu@google.com> | 2017-02-14 17:30:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-14 17:47:15 -0800 |
commit | dd51f989b8ca738da8a04970857597ed68fa1a15 (patch) | |
tree | 6f7738c4d94c45b958f9d279ddaf28c71077ba57 /tensorflow/go/saved_model.go | |
parent | b5d6636e5d60dc3094285d097264c3d1cf250a53 (diff) |
Go: Add a SavedModel type
Change: 147543652
Diffstat (limited to 'tensorflow/go/saved_model.go')
-rw-r--r-- | tensorflow/go/saved_model.go | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/tensorflow/go/saved_model.go b/tensorflow/go/saved_model.go new file mode 100644 index 0000000000..9bffa61765 --- /dev/null +++ b/tensorflow/go/saved_model.go @@ -0,0 +1,68 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package tensorflow + +import ( + "runtime" + "unsafe" +) + +// #include <stdlib.h> +// #include "tensorflow/c/c_api.h" +import "C" + +// SavedModel represents the contents of loaded SavedModel. +// TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs. +type SavedModel struct { + Session *Session + Graph *Graph +} + +// LoadSavedModel creates a new SavedModel from a model previously +// exported to a directory on disk. +// +// Exported models contain a set of graphs and, optionally, variable values. +// Tags in the model identify a single graph. LoadSavedModel initializes a +// session with the identified graph and with variables initialized to from the +// checkpoints on disk. +// +// The tensorflow package currently does not have the ability to export a model +// to a directory from Go. This function thus currently targets loading models +// exported in other languages, such as using tf.saved_model.builder in Python. +// See: +// https://www.tensorflow.org/code/tensorflow/python/saved_model/ +func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error) { + status := newStatus() + cOpt := options.c() + cExportDir := C.CString(exportDir) + cTags := make([]*C.char, len(tags)) + for i := range tags { + cTags[i] = C.CString(tags[i]) + } + graph := NewGraph() + // TODO(jhseu): Add support for run_options and meta_graph_def. + cSess := C.TF_LoadSessionFromSavedModel(cOpt, nil, cExportDir, (**C.char)(unsafe.Pointer(&cTags[0])), C.int(len(cTags)), graph.c, nil, status.c) + for i := range cTags { + C.free(unsafe.Pointer(cTags[i])) + } + C.free(unsafe.Pointer(cExportDir)) + C.TF_DeleteSessionOptions(cOpt) + + if err := status.Err(); err != nil { + return nil, err + } + s := &Session{c: cSess} + runtime.SetFinalizer(s, func(s *Session) { s.Close() }) + return &SavedModel{Session: s, Graph: graph}, nil +} |