diff options
author | 2017-02-28 03:01:11 -0800 | |
---|---|---|
committer | 2017-02-28 03:09:47 -0800 | |
commit | 49a4ebbf3cb307c513653427b32f30ad35855094 (patch) | |
tree | 2b3889a2f8d0bdeb441b654384f5ec61de5b51e3 /tensorflow/go/session.go | |
parent | 5a31e9c8bd73265aa76a6ba70e780fcf432b2abf (diff) |
Go: Provide a mechanism to configure the Session.
A Session is configured using the ConfigProto protocol buffer.
For now, continuing with attempts to keep the 'tensorflow' go package
free of any protocol buffer dependencies, SessionOptions uses a serialized
representation of this message. This choice might make sense to revisit.
Change: 148750535
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r-- | tensorflow/go/session.go | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index ef357cb520..5a6e1e37ad 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -20,6 +20,7 @@ import "C" import ( "errors" + "fmt" "runtime" "sync" "unsafe" @@ -47,9 +48,12 @@ type Session struct { // options may be nil to use the default options. func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { status := newStatus() - cOpt := options.c() + cOpt, doneOpt, err := options.c() + defer doneOpt() + if err != nil { + return nil, err + } cSess := C.TF_NewSession(graph.c, cOpt, status.c) - C.TF_DeleteSessionOptions(cOpt) if err := status.Err(); err != nil { return nil, err } @@ -243,19 +247,42 @@ type SessionOptions struct { // If the session disconnects from the remote process during its // lifetime, session calls may fail immediately. Target string + + // Config is a binary-serialized representation of the + // tensorflow.ConfigProto protocol message + // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto). + Config []byte } // c converts the SessionOptions to the C API's TF_SessionOptions. Callers must -// deallocate by calling C.TF_DeleteSessionOptions(). -func (o *SessionOptions) c() *C.TF_SessionOptions { +// deallocate by calling the returned done() closure. +func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) { opt := C.TF_NewSessionOptions() if o == nil { - return opt + return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil } t := C.CString(o.Target) C.TF_SetTarget(opt, t) C.free(unsafe.Pointer(t)) - return opt + + var cConfig unsafe.Pointer + if sz := len(o.Config); sz > 0 { + status := newStatus() + // Copying into C-memory is the simplest thing to do in terms + // of memory safety and cgo rules ("C code may not keep a copy + // of a Go pointer after the call returns" from + // https://golang.org/cmd/cgo/#hdr-Passing_pointers). + cConfig = C.CBytes(o.Config) + C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c) + if err := status.Err(); err != nil { + C.TF_DeleteSessionOptions(opt) + return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err) + } + } + return opt, func() { + C.TF_DeleteSessionOptions(opt) + C.free(cConfig) + }, nil } // cRunArgs translates the arguments to Session.Run and PartialRun.Run into |