aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-09-23 10:25:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-23 11:33:41 -0700
commit7bffed6ba12a4d463a0dfb1978e2108343c63eaf (patch)
tree292e668023177bb637d1fae08f1905b843f3a32e /tensorflow/go/graph.go
parentf982313d3c32cf2c7018eae9ec49759ae111d407 (diff)
go: Ability to import a pre-defined Graph.
With this change, it should be possible to execute a pre-defined Graph created by any means (like a training session in a Python program) in Go. One more step towards #10 Change: 134096795
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r--tensorflow/go/graph.go75
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 3e43a39817..e119839218 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -15,10 +15,16 @@
package tensorflow
// #include "tensorflow/c/c_api.h"
+//
+// #include <stdlib.h>
+// #include <string.h>
import "C"
import (
+ "fmt"
+ "io"
"runtime"
+ "unsafe"
)
// Graph represents a computation graph. Graphs may be shared between sessions.
@@ -36,3 +42,72 @@ func NewGraph() *Graph {
func (g *Graph) finalizer() {
C.TF_DeleteGraph(g.c)
}
+
+// WriteTo writes out a serialized representation of g to w.
+//
+// Implements the io.WriterTo interface.
+func (g *Graph) WriteTo(w io.Writer) (int64, error) {
+ buf := C.TF_NewBuffer()
+ defer C.TF_DeleteBuffer(buf)
+ status := newStatus()
+ C.TF_GraphToGraphDef(g.c, buf, status.c)
+ if err := status.Err(); err != nil {
+ return 0, err
+ }
+ if buf.length > (1 << 30) {
+ // For very large graphs, the writes can be chunked.
+ // Punt on that for now.
+ return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated")
+ }
+ // A []byte slice backed by C memory.
+ // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
+ length := int(buf.length)
+ slice := (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length]
+ n, err := w.Write(slice)
+ return int64(n), err
+}
+
+// Import imports the nodes and edges from a serialized representation of
+// another Graph into g.
+//
+// Names of imported nodes will be prefixed with prefix.
+func (g *Graph) Import(def []byte, prefix string) error {
+ cprefix := C.CString(prefix)
+ defer C.free(unsafe.Pointer(cprefix))
+
+ opts := C.TF_NewImportGraphDefOptions()
+ defer C.TF_DeleteImportGraphDefOptions(opts)
+ C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix)
+
+ buf := C.TF_NewBuffer()
+ defer C.TF_DeleteBuffer(buf)
+ // Would have preferred to use C.CBytes, but that does not play well
+ // with "go vet" till https://github.com/golang/go/issues/17201 is
+ // resolved.
+ buf.length = C.size_t(len(def))
+ buf.data = C.malloc(buf.length)
+ if buf.data == nil {
+ return fmt.Errorf("unable to allocate memory")
+ }
+ defer C.free(buf.data)
+ C.memcpy(buf.data, unsafe.Pointer(&def[0]), buf.length)
+
+ status := newStatus()
+ C.TF_GraphImportGraphDef(g.c, buf, opts, status.c)
+ if err := status.Err(); err != nil {
+ return err
+ }
+ return nil
+}
+
+// Operation returns the Operation named name in the Graph, or nil if no such
+// operation is present.
+func (g *Graph) Operation(name string) *Operation {
+ cname := C.CString(name)
+ defer C.free(unsafe.Pointer(cname))
+ cop := C.TF_GraphOperationByName(g.c, cname)
+ if cop == nil {
+ return nil
+ }
+ return &Operation{cop,g}
+}