diff options
author | 2017-02-16 16:28:34 -0800 | |
---|---|---|
committer | 2017-02-16 16:47:36 -0800 | |
commit | c35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch) | |
tree | 8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10 /tensorflow/go/session.go | |
parent | 5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff) |
Go: Add PartialRun support.
Change: 147783087
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r-- | tensorflow/go/session.go | 214 |
1 files changed, 171 insertions, 43 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index dd629441ef..ef357cb520 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -59,14 +59,14 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { return s, nil } -// Run the graph with the associated session starting with the supplied inputs. -// inputs and outputs may be set to nil. Runs, but does not return Tensors -// for operations specified in targets. +// Run the graph with the associated session starting with the supplied feeds +// to compute the value of the requested fetches. Runs, but does not return +// Tensors for operations specified in targets. // -// On success, returns the Tensor outputs in the same order as supplied in -// the outputs argument. If outputs is set to nil, the returned Tensor outputs +// On success, returns the fetched Tensors in the same order as supplied in +// the fetches argument. If fetches is set to nil, the returned Tensor fetches // is empty. -func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) { +func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { s.mu.Lock() if s.c == nil { s.mu.Unlock() @@ -76,56 +76,126 @@ func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Op s.mu.Unlock() defer s.wg.Done() - var inputPorts []C.TF_Output - var inputValues []*C.TF_Tensor - if inputs != nil { - for port, tensor := range inputs { - inputPorts = append(inputPorts, port.c()) - inputValues = append(inputValues, tensor.c) - } + c := newCRunArgs(feeds, fetches, targets) + status := newStatus() + C.TF_SessionRun(s.c, nil, + ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), + ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), + ptrOperation(c.targets), C.int(len(targets)), + nil, status.c) + if err := status.Err(); err != nil { + return nil, err } + return c.toGo(), nil +} - var outputPorts []C.TF_Output - for _, port := range outputs { - outputPorts = append(outputPorts, port.c()) - } - outputValues := make([]*C.TF_Tensor, len(outputs)) - var cTargets []*C.TF_Operation - for _, target := range targets { - cTargets = append(cTargets, target.c) +// PartialRun enables incremental evaluation of graphs. +// +// PartialRun allows the caller to pause the evaluation of a graph, run +// arbitrary code that depends on the intermediate computation of the graph, +// and then resume graph execution. The results of the arbitrary code can be +// fed into the graph when resuming execution. In contrast, Session.Run +// executes the graph to compute the requested fetches using the provided feeds +// and discards all intermediate state (e.g., value of intermediate tensors) +// when it returns. +// +// For example, consider a graph for unsupervised training of a neural network +// model. PartialRun can be used to pause execution after the forward pass of +// the network, let the caller actuate the output (e.g., play a game, actuate a +// robot etc.), determine the error/loss and then feed this calculated loss +// when resuming the backward pass of the graph. +type PartialRun struct { + session *Session + handle *C.char +} + +// Run resumes execution of the graph to compute the requested fetches and +// targets with the provided feeds. +func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { + var ( + c = newCRunArgs(feeds, fetches, targets) + status = newStatus() + s = pr.session + ) + s.mu.Lock() + if s.c == nil { + s.mu.Unlock() + return nil, errors.New("session is closed") } + s.wg.Add(1) + s.mu.Unlock() + defer s.wg.Done() - status := newStatus() - var inputPortsPtr *C.TF_Output - var inputValuesPtr **C.TF_Tensor - if len(inputPorts) > 0 { - inputPortsPtr = &inputPorts[0] - inputValuesPtr = &inputValues[0] + C.TF_SessionPRun(s.c, pr.handle, + ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), + ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), + ptrOperation(c.targets), C.int(len(targets)), + status.c) + if err := status.Err(); err != nil { + return nil, err } + return c.toGo(), nil +} + +// NewPartialRun sets up the graph for incremental evaluation. +// +// All values of feeds, fetches and targets that may be provided to Run calls +// on the returned PartialRun need to be provided to NewPartialRun. +// +// See documentation for the PartialRun type. +func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) { + var ( + cfeeds = make([]C.TF_Output, len(feeds)) + cfetches = make([]C.TF_Output, len(fetches)) + ctargets = make([]*C.TF_Operation, len(targets)) - var outputPortsPtr *C.TF_Output - var outputValuesPtr **C.TF_Tensor - if len(outputPorts) > 0 { - outputPortsPtr = &outputPorts[0] - outputValuesPtr = &outputValues[0] + pcfeeds *C.TF_Output + pcfetches *C.TF_Output + pctargets **C.TF_Operation + + status = newStatus() + ) + if len(feeds) > 0 { + pcfeeds = &cfeeds[0] + for i, o := range feeds { + cfeeds[i] = o.c() + } + } + if len(fetches) > 0 { + pcfetches = &cfetches[0] + for i, o := range fetches { + cfetches[i] = o.c() + } + } + if len(targets) > 0 { + pctargets = &ctargets[0] + for i, o := range targets { + ctargets[i] = o.c + } } - var cTargetsPtr **C.TF_Operation - if len(cTargets) > 0 { - cTargetsPtr = &cTargets[0] + s.mu.Lock() + if s.c == nil { + s.mu.Unlock() + return nil, errors.New("session is closed") } + s.wg.Add(1) + s.mu.Unlock() + defer s.wg.Done() - C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c) + pr := &PartialRun{session: s} + C.TF_SessionPRunSetup(s.c, + pcfeeds, C.int(len(feeds)), + pcfetches, C.int(len(fetches)), + pctargets, C.int(len(targets)), + &pr.handle, status.c) if err := status.Err(); err != nil { return nil, err } - - tensors := make([]*Tensor, len(outputValues)) - for i, val := range outputValues { - tensors[i] = newTensorFromC(val) - } - - return tensors, nil + runtime.SetFinalizer(pr, func(pr *PartialRun) { + deletePRunHandle(pr.handle) + }) + return pr, nil } // Close a session. This contacts any other processes associated with this @@ -187,3 +257,61 @@ func (o *SessionOptions) c() *C.TF_SessionOptions { C.free(unsafe.Pointer(t)) return opt } + +// cRunArgs translates the arguments to Session.Run and PartialRun.Run into +// values suitable for C library calls. +type cRunArgs struct { + feeds []C.TF_Output + feedTensors []*C.TF_Tensor + fetches []C.TF_Output + fetchTensors []*C.TF_Tensor + targets []*C.TF_Operation +} + +func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs { + c := &cRunArgs{ + fetches: make([]C.TF_Output, len(fetches)), + fetchTensors: make([]*C.TF_Tensor, len(fetches)), + targets: make([]*C.TF_Operation, len(targets)), + } + for o, t := range feeds { + c.feeds = append(c.feeds, o.c()) + c.feedTensors = append(c.feedTensors, t.c) + } + for i, o := range fetches { + c.fetches[i] = o.c() + } + for i, t := range targets { + c.targets[i] = t.c + } + return c +} + +func (c *cRunArgs) toGo() []*Tensor { + ret := make([]*Tensor, len(c.fetchTensors)) + for i, ct := range c.fetchTensors { + ret[i] = newTensorFromC(ct) + } + return ret +} + +func ptrOutput(l []C.TF_Output) *C.TF_Output { + if len(l) == 0 { + return nil + } + return &l[0] +} + +func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor { + if len(l) == 0 { + return nil + } + return &l[0] +} + +func ptrOperation(l []*C.TF_Operation) **C.TF_Operation { + if len(l) == 0 { + return nil + } + return &l[0] +} |