aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-28 03:01:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-28 03:09:47 -0800
commit49a4ebbf3cb307c513653427b32f30ad35855094 (patch)
tree2b3889a2f8d0bdeb441b654384f5ec61de5b51e3 /tensorflow/go/session.go
parent5a31e9c8bd73265aa76a6ba70e780fcf432b2abf (diff)
Go: Provide a mechanism to configure the Session.
A Session is configured using the ConfigProto protocol buffer. For now, continuing with attempts to keep the 'tensorflow' go package free of any protocol buffer dependencies, SessionOptions uses a serialized representation of this message. This choice might make sense to revisit. Change: 148750535
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r--tensorflow/go/session.go39
1 files changed, 33 insertions, 6 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index ef357cb520..5a6e1e37ad 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -20,6 +20,7 @@ import "C"
import (
"errors"
+ "fmt"
"runtime"
"sync"
"unsafe"
@@ -47,9 +48,12 @@ type Session struct {
// options may be nil to use the default options.
func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
status := newStatus()
- cOpt := options.c()
+ cOpt, doneOpt, err := options.c()
+ defer doneOpt()
+ if err != nil {
+ return nil, err
+ }
cSess := C.TF_NewSession(graph.c, cOpt, status.c)
- C.TF_DeleteSessionOptions(cOpt)
if err := status.Err(); err != nil {
return nil, err
}
@@ -243,19 +247,42 @@ type SessionOptions struct {
// If the session disconnects from the remote process during its
// lifetime, session calls may fail immediately.
Target string
+
+ // Config is a binary-serialized representation of the
+ // tensorflow.ConfigProto protocol message
+ // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
+ Config []byte
}
// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
-// deallocate by calling C.TF_DeleteSessionOptions().
-func (o *SessionOptions) c() *C.TF_SessionOptions {
+// deallocate by calling the returned done() closure.
+func (o *SessionOptions) c() (ret *C.TF_SessionOptions, done func(), err error) {
opt := C.TF_NewSessionOptions()
if o == nil {
- return opt
+ return opt, func() { C.TF_DeleteSessionOptions(opt) }, nil
}
t := C.CString(o.Target)
C.TF_SetTarget(opt, t)
C.free(unsafe.Pointer(t))
- return opt
+
+ var cConfig unsafe.Pointer
+ if sz := len(o.Config); sz > 0 {
+ status := newStatus()
+ // Copying into C-memory is the simplest thing to do in terms
+ // of memory safety and cgo rules ("C code may not keep a copy
+ // of a Go pointer after the call returns" from
+ // https://golang.org/cmd/cgo/#hdr-Passing_pointers).
+ cConfig = C.CBytes(o.Config)
+ C.TF_SetConfig(opt, cConfig, C.size_t(sz), status.c)
+ if err := status.Err(); err != nil {
+ C.TF_DeleteSessionOptions(opt)
+ return nil, func() {}, fmt.Errorf("invalid SessionOptions.Config: %v", err)
+ }
+ }
+ return opt, func() {
+ C.TF_DeleteSessionOptions(opt)
+ C.free(cConfig)
+ }, nil
}
// cRunArgs translates the arguments to Session.Run and PartialRun.Run into