aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-16 16:28:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-16 16:47:36 -0800
commitc35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch)
tree8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10 /tensorflow/go/session.go
parent5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff)
Go: Add PartialRun support.
Change: 147783087
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r--tensorflow/go/session.go214
1 files changed, 171 insertions, 43 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index dd629441ef..ef357cb520 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -59,14 +59,14 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
-// Run the graph with the associated session starting with the supplied inputs.
-// inputs and outputs may be set to nil. Runs, but does not return Tensors
-// for operations specified in targets.
+// Run the graph with the associated session starting with the supplied feeds
+// to compute the value of the requested fetches. Runs, but does not return
+// Tensors for operations specified in targets.
//
-// On success, returns the Tensor outputs in the same order as supplied in
-// the outputs argument. If outputs is set to nil, the returned Tensor outputs
+// On success, returns the fetched Tensors in the same order as supplied in
+// the fetches argument. If fetches is set to nil, the returned Tensor fetches
// is empty.
-func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) {
+func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
s.mu.Lock()
if s.c == nil {
s.mu.Unlock()
@@ -76,56 +76,126 @@ func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Op
s.mu.Unlock()
defer s.wg.Done()
- var inputPorts []C.TF_Output
- var inputValues []*C.TF_Tensor
- if inputs != nil {
- for port, tensor := range inputs {
- inputPorts = append(inputPorts, port.c())
- inputValues = append(inputValues, tensor.c)
- }
+ c := newCRunArgs(feeds, fetches, targets)
+ status := newStatus()
+ C.TF_SessionRun(s.c, nil,
+ ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
+ ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
+ ptrOperation(c.targets), C.int(len(targets)),
+ nil, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
}
+ return c.toGo(), nil
+}
- var outputPorts []C.TF_Output
- for _, port := range outputs {
- outputPorts = append(outputPorts, port.c())
- }
- outputValues := make([]*C.TF_Tensor, len(outputs))
- var cTargets []*C.TF_Operation
- for _, target := range targets {
- cTargets = append(cTargets, target.c)
+// PartialRun enables incremental evaluation of graphs.
+//
+// PartialRun allows the caller to pause the evaluation of a graph, run
+// arbitrary code that depends on the intermediate computation of the graph,
+// and then resume graph execution. The results of the arbitrary code can be
+// fed into the graph when resuming execution. In contrast, Session.Run
+// executes the graph to compute the requested fetches using the provided feeds
+// and discards all intermediate state (e.g., value of intermediate tensors)
+// when it returns.
+//
+// For example, consider a graph for unsupervised training of a neural network
+// model. PartialRun can be used to pause execution after the forward pass of
+// the network, let the caller actuate the output (e.g., play a game, actuate a
+// robot etc.), determine the error/loss and then feed this calculated loss
+// when resuming the backward pass of the graph.
+type PartialRun struct {
+ session *Session
+ handle *C.char
+}
+
+// Run resumes execution of the graph to compute the requested fetches and
+// targets with the provided feeds.
+func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
+ var (
+ c = newCRunArgs(feeds, fetches, targets)
+ status = newStatus()
+ s = pr.session
+ )
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
}
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
- status := newStatus()
- var inputPortsPtr *C.TF_Output
- var inputValuesPtr **C.TF_Tensor
- if len(inputPorts) > 0 {
- inputPortsPtr = &inputPorts[0]
- inputValuesPtr = &inputValues[0]
+ C.TF_SessionPRun(s.c, pr.handle,
+ ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
+ ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
+ ptrOperation(c.targets), C.int(len(targets)),
+ status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
}
+ return c.toGo(), nil
+}
+
+// NewPartialRun sets up the graph for incremental evaluation.
+//
+// All values of feeds, fetches and targets that may be provided to Run calls
+// on the returned PartialRun need to be provided to NewPartialRun.
+//
+// See documentation for the PartialRun type.
+func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) {
+ var (
+ cfeeds = make([]C.TF_Output, len(feeds))
+ cfetches = make([]C.TF_Output, len(fetches))
+ ctargets = make([]*C.TF_Operation, len(targets))
- var outputPortsPtr *C.TF_Output
- var outputValuesPtr **C.TF_Tensor
- if len(outputPorts) > 0 {
- outputPortsPtr = &outputPorts[0]
- outputValuesPtr = &outputValues[0]
+ pcfeeds *C.TF_Output
+ pcfetches *C.TF_Output
+ pctargets **C.TF_Operation
+
+ status = newStatus()
+ )
+ if len(feeds) > 0 {
+ pcfeeds = &cfeeds[0]
+ for i, o := range feeds {
+ cfeeds[i] = o.c()
+ }
+ }
+ if len(fetches) > 0 {
+ pcfetches = &cfetches[0]
+ for i, o := range fetches {
+ cfetches[i] = o.c()
+ }
+ }
+ if len(targets) > 0 {
+ pctargets = &ctargets[0]
+ for i, o := range targets {
+ ctargets[i] = o.c
+ }
}
- var cTargetsPtr **C.TF_Operation
- if len(cTargets) > 0 {
- cTargetsPtr = &cTargets[0]
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
}
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
- C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c)
+ pr := &PartialRun{session: s}
+ C.TF_SessionPRunSetup(s.c,
+ pcfeeds, C.int(len(feeds)),
+ pcfetches, C.int(len(fetches)),
+ pctargets, C.int(len(targets)),
+ &pr.handle, status.c)
if err := status.Err(); err != nil {
return nil, err
}
-
- tensors := make([]*Tensor, len(outputValues))
- for i, val := range outputValues {
- tensors[i] = newTensorFromC(val)
- }
-
- return tensors, nil
+ runtime.SetFinalizer(pr, func(pr *PartialRun) {
+ deletePRunHandle(pr.handle)
+ })
+ return pr, nil
}
// Close a session. This contacts any other processes associated with this
@@ -187,3 +257,61 @@ func (o *SessionOptions) c() *C.TF_SessionOptions {
C.free(unsafe.Pointer(t))
return opt
}
+
+// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
+// values suitable for C library calls.
+type cRunArgs struct {
+ feeds []C.TF_Output
+ feedTensors []*C.TF_Tensor
+ fetches []C.TF_Output
+ fetchTensors []*C.TF_Tensor
+ targets []*C.TF_Operation
+}
+
+func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs {
+ c := &cRunArgs{
+ fetches: make([]C.TF_Output, len(fetches)),
+ fetchTensors: make([]*C.TF_Tensor, len(fetches)),
+ targets: make([]*C.TF_Operation, len(targets)),
+ }
+ for o, t := range feeds {
+ c.feeds = append(c.feeds, o.c())
+ c.feedTensors = append(c.feedTensors, t.c)
+ }
+ for i, o := range fetches {
+ c.fetches[i] = o.c()
+ }
+ for i, t := range targets {
+ c.targets[i] = t.c
+ }
+ return c
+}
+
+func (c *cRunArgs) toGo() []*Tensor {
+ ret := make([]*Tensor, len(c.fetchTensors))
+ for i, ct := range c.fetchTensors {
+ ret[i] = newTensorFromC(ct)
+ }
+ return ret
+}
+
+func ptrOutput(l []C.TF_Output) *C.TF_Output {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}
+
+func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}
+
+func ptrOperation(l []*C.TF_Operation) **C.TF_Operation {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}