diff options
author | Asim Shankar <ashankar@google.com> | 2016-09-23 10:25:02 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-23 11:33:41 -0700 |
commit | 7bffed6ba12a4d463a0dfb1978e2108343c63eaf (patch) | |
tree | 292e668023177bb637d1fae08f1905b843f3a32e /tensorflow/go/graph.go | |
parent | f982313d3c32cf2c7018eae9ec49759ae111d407 (diff) |
go: Ability to import a pre-defined Graph.
With this change, it should be possible to execute
a pre-defined Graph created by any means (like a
training session in a Python program) in Go.
One more step towards #10
Change: 134096795
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 3e43a39817..e119839218 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -15,10 +15,16 @@ package tensorflow // #include "tensorflow/c/c_api.h" +// +// #include <stdlib.h> +// #include <string.h> import "C" import ( + "fmt" + "io" "runtime" + "unsafe" ) // Graph represents a computation graph. Graphs may be shared between sessions. @@ -36,3 +42,72 @@ func NewGraph() *Graph { func (g *Graph) finalizer() { C.TF_DeleteGraph(g.c) } + +// WriteTo writes out a serialized representation of g to w. +// +// Implements the io.WriterTo interface. +func (g *Graph) WriteTo(w io.Writer) (int64, error) { + buf := C.TF_NewBuffer() + defer C.TF_DeleteBuffer(buf) + status := newStatus() + C.TF_GraphToGraphDef(g.c, buf, status.c) + if err := status.Err(); err != nil { + return 0, err + } + if buf.length > (1 << 30) { + // For very large graphs, the writes can be chunked. + // Punt on that for now. + return 0, fmt.Errorf("Graph is too large to write out, Graph.WriteTo needs to be updated") + } + // A []byte slice backed by C memory. + // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices + length := int(buf.length) + slice := (*[1 << 30]byte)(unsafe.Pointer(buf.data))[:length:length] + n, err := w.Write(slice) + return int64(n), err +} + +// Import imports the nodes and edges from a serialized representation of +// another Graph into g. +// +// Names of imported nodes will be prefixed with prefix. +func (g *Graph) Import(def []byte, prefix string) error { + cprefix := C.CString(prefix) + defer C.free(unsafe.Pointer(cprefix)) + + opts := C.TF_NewImportGraphDefOptions() + defer C.TF_DeleteImportGraphDefOptions(opts) + C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix) + + buf := C.TF_NewBuffer() + defer C.TF_DeleteBuffer(buf) + // Would have preferred to use C.CBytes, but that does not play well + // with "go vet" till https://github.com/golang/go/issues/17201 is + // resolved. + buf.length = C.size_t(len(def)) + buf.data = C.malloc(buf.length) + if buf.data == nil { + return fmt.Errorf("unable to allocate memory") + } + defer C.free(buf.data) + C.memcpy(buf.data, unsafe.Pointer(&def[0]), buf.length) + + status := newStatus() + C.TF_GraphImportGraphDef(g.c, buf, opts, status.c) + if err := status.Err(); err != nil { + return err + } + return nil +} + +// Operation returns the Operation named name in the Graph, or nil if no such +// operation is present. +func (g *Graph) Operation(name string) *Operation { + cname := C.CString(name) + defer C.free(unsafe.Pointer(cname)) + cop := C.TF_GraphOperationByName(g.c, cname) + if cop == nil { + return nil + } + return &Operation{cop,g} +} |