diff options
author | Asim Shankar <ashankar@google.com> | 2017-02-16 16:28:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-16 16:47:36 -0800 |
commit | c35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch) | |
tree | 8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10 | |
parent | 5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff) |
Go: Add PartialRun support.
Change: 147783087
-rw-r--r-- | tensorflow/go/lib.go | 10 | ||||
-rw-r--r-- | tensorflow/go/session.cpp | 24 | ||||
-rw-r--r-- | tensorflow/go/session.go | 214 | ||||
-rw-r--r-- | tensorflow/go/session_test.go | 65 | ||||
-rw-r--r-- | tensorflow/go/util_test.go | 13 |
5 files changed, 281 insertions, 45 deletions
diff --git a/tensorflow/go/lib.go b/tensorflow/go/lib.go index dcab7a90f8..7f96c7809a 100644 --- a/tensorflow/go/lib.go +++ b/tensorflow/go/lib.go @@ -16,4 +16,14 @@ package tensorflow // #cgo LDFLAGS: -ltensorflow // #cgo CFLAGS: -I${SRCDIR}/../../ +// +// // TODO(ashankar): Remove this after TensorFlow 1.1 has been released. +// // Till then, the TensorFlow C API binary releases do not contain +// // the TF_DeletePRunHandle symbol. We work around that by +// // implementing the equivalent in session.cpp +// extern void tfDeletePRunHandle(const char*); import "C" + +func deletePRunHandle(h *C.char) { + C.tfDeletePRunHandle(h) +} diff --git a/tensorflow/go/session.cpp b/tensorflow/go/session.cpp new file mode 100644 index 0000000000..9f6fd1f341 --- /dev/null +++ b/tensorflow/go/session.cpp @@ -0,0 +1,24 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO(ashankar): Remove this file when TensorFlow 1.1 is released. +// See lib.go for details. + +extern "C" { +extern void tfDeletePRunHandle(const char* h); +} + +void tfDeletePRunHandle(const char* h) { + delete[] h; +} diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go index dd629441ef..ef357cb520 100644 --- a/tensorflow/go/session.go +++ b/tensorflow/go/session.go @@ -59,14 +59,14 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { return s, nil } -// Run the graph with the associated session starting with the supplied inputs. -// inputs and outputs may be set to nil. Runs, but does not return Tensors -// for operations specified in targets. +// Run the graph with the associated session starting with the supplied feeds +// to compute the value of the requested fetches. Runs, but does not return +// Tensors for operations specified in targets. // -// On success, returns the Tensor outputs in the same order as supplied in -// the outputs argument. If outputs is set to nil, the returned Tensor outputs +// On success, returns the fetched Tensors in the same order as supplied in +// the fetches argument. If fetches is set to nil, the returned Tensor fetches // is empty. -func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) { +func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { s.mu.Lock() if s.c == nil { s.mu.Unlock() @@ -76,56 +76,126 @@ func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Op s.mu.Unlock() defer s.wg.Done() - var inputPorts []C.TF_Output - var inputValues []*C.TF_Tensor - if inputs != nil { - for port, tensor := range inputs { - inputPorts = append(inputPorts, port.c()) - inputValues = append(inputValues, tensor.c) - } + c := newCRunArgs(feeds, fetches, targets) + status := newStatus() + C.TF_SessionRun(s.c, nil, + ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), + ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), + ptrOperation(c.targets), C.int(len(targets)), + nil, status.c) + if err := status.Err(); err != nil { + return nil, err } + return c.toGo(), nil +} - var outputPorts []C.TF_Output - for _, port := range outputs { - outputPorts = append(outputPorts, port.c()) - } - outputValues := make([]*C.TF_Tensor, len(outputs)) - var cTargets []*C.TF_Operation - for _, target := range targets { - cTargets = append(cTargets, target.c) +// PartialRun enables incremental evaluation of graphs. +// +// PartialRun allows the caller to pause the evaluation of a graph, run +// arbitrary code that depends on the intermediate computation of the graph, +// and then resume graph execution. The results of the arbitrary code can be +// fed into the graph when resuming execution. In contrast, Session.Run +// executes the graph to compute the requested fetches using the provided feeds +// and discards all intermediate state (e.g., value of intermediate tensors) +// when it returns. +// +// For example, consider a graph for unsupervised training of a neural network +// model. PartialRun can be used to pause execution after the forward pass of +// the network, let the caller actuate the output (e.g., play a game, actuate a +// robot etc.), determine the error/loss and then feed this calculated loss +// when resuming the backward pass of the graph. +type PartialRun struct { + session *Session + handle *C.char +} + +// Run resumes execution of the graph to compute the requested fetches and +// targets with the provided feeds. +func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) { + var ( + c = newCRunArgs(feeds, fetches, targets) + status = newStatus() + s = pr.session + ) + s.mu.Lock() + if s.c == nil { + s.mu.Unlock() + return nil, errors.New("session is closed") } + s.wg.Add(1) + s.mu.Unlock() + defer s.wg.Done() - status := newStatus() - var inputPortsPtr *C.TF_Output - var inputValuesPtr **C.TF_Tensor - if len(inputPorts) > 0 { - inputPortsPtr = &inputPorts[0] - inputValuesPtr = &inputValues[0] + C.TF_SessionPRun(s.c, pr.handle, + ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)), + ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)), + ptrOperation(c.targets), C.int(len(targets)), + status.c) + if err := status.Err(); err != nil { + return nil, err } + return c.toGo(), nil +} + +// NewPartialRun sets up the graph for incremental evaluation. +// +// All values of feeds, fetches and targets that may be provided to Run calls +// on the returned PartialRun need to be provided to NewPartialRun. +// +// See documentation for the PartialRun type. +func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) { + var ( + cfeeds = make([]C.TF_Output, len(feeds)) + cfetches = make([]C.TF_Output, len(fetches)) + ctargets = make([]*C.TF_Operation, len(targets)) - var outputPortsPtr *C.TF_Output - var outputValuesPtr **C.TF_Tensor - if len(outputPorts) > 0 { - outputPortsPtr = &outputPorts[0] - outputValuesPtr = &outputValues[0] + pcfeeds *C.TF_Output + pcfetches *C.TF_Output + pctargets **C.TF_Operation + + status = newStatus() + ) + if len(feeds) > 0 { + pcfeeds = &cfeeds[0] + for i, o := range feeds { + cfeeds[i] = o.c() + } + } + if len(fetches) > 0 { + pcfetches = &cfetches[0] + for i, o := range fetches { + cfetches[i] = o.c() + } + } + if len(targets) > 0 { + pctargets = &ctargets[0] + for i, o := range targets { + ctargets[i] = o.c + } } - var cTargetsPtr **C.TF_Operation - if len(cTargets) > 0 { - cTargetsPtr = &cTargets[0] + s.mu.Lock() + if s.c == nil { + s.mu.Unlock() + return nil, errors.New("session is closed") } + s.wg.Add(1) + s.mu.Unlock() + defer s.wg.Done() - C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c) + pr := &PartialRun{session: s} + C.TF_SessionPRunSetup(s.c, + pcfeeds, C.int(len(feeds)), + pcfetches, C.int(len(fetches)), + pctargets, C.int(len(targets)), + &pr.handle, status.c) if err := status.Err(); err != nil { return nil, err } - - tensors := make([]*Tensor, len(outputValues)) - for i, val := range outputValues { - tensors[i] = newTensorFromC(val) - } - - return tensors, nil + runtime.SetFinalizer(pr, func(pr *PartialRun) { + deletePRunHandle(pr.handle) + }) + return pr, nil } // Close a session. This contacts any other processes associated with this @@ -187,3 +257,61 @@ func (o *SessionOptions) c() *C.TF_SessionOptions { C.free(unsafe.Pointer(t)) return opt } + +// cRunArgs translates the arguments to Session.Run and PartialRun.Run into +// values suitable for C library calls. +type cRunArgs struct { + feeds []C.TF_Output + feedTensors []*C.TF_Tensor + fetches []C.TF_Output + fetchTensors []*C.TF_Tensor + targets []*C.TF_Operation +} + +func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs { + c := &cRunArgs{ + fetches: make([]C.TF_Output, len(fetches)), + fetchTensors: make([]*C.TF_Tensor, len(fetches)), + targets: make([]*C.TF_Operation, len(targets)), + } + for o, t := range feeds { + c.feeds = append(c.feeds, o.c()) + c.feedTensors = append(c.feedTensors, t.c) + } + for i, o := range fetches { + c.fetches[i] = o.c() + } + for i, t := range targets { + c.targets[i] = t.c + } + return c +} + +func (c *cRunArgs) toGo() []*Tensor { + ret := make([]*Tensor, len(c.fetchTensors)) + for i, ct := range c.fetchTensors { + ret[i] = newTensorFromC(ct) + } + return ret +} + +func ptrOutput(l []C.TF_Output) *C.TF_Output { + if len(l) == 0 { + return nil + } + return &l[0] +} + +func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor { + if len(l) == 0 { + return nil + } + return &l[0] +} + +func ptrOperation(l []*C.TF_Operation) **C.TF_Operation { + if len(l) == 0 { + return nil + } + return &l[0] +} diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 14ecca402b..9afa2be3b4 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -181,3 +181,68 @@ func TestConcurrency(t *testing.T) { t.Errorf("Close() 2: %v", err) } } + +func ExamplePartialRun() { + var ( + // Create a graph: a + 2 + 3 + b. + // + // Skipping error handling for brevity of this example. + // The 'op' package can be used to make graph construction code + // with error handling more succinct. + g = NewGraph() + a, _ = Placeholder(g, "a", Int32) + b, _ = Placeholder(g, "b", Int32) + two, _ = Const(g, "Two", int32(2)) + three, _ = Const(g, "Three", int32(3)) + + plus2, _ = Add(g, "plus2", a, two) // a + 2 + plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3 + plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b + + ) + sess, err := NewSession(g, nil) + if err != nil { + panic(err) + } + defer sess.Close() + + // All the feeds, fetches and targets for subsequent PartialRun.Run + // calls must be provided at setup. + pr, err := sess.NewPartialRun( + []Output{a, b}, + []Output{plus2, plusB}, + []*Operation{plus3.Op}, + ) + if err != nil { + panic(err) + } + + // Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'. + // Imagine this to be the forward pass of unsupervised neural network + // training of a robot. + val, _ := NewTensor(int32(1)) + fetches, err := pr.Run( + map[Output]*Tensor{a: val}, + []Output{plus2}, + nil) + if err != nil { + panic(err) + } + v1 := fetches[0].Value().(int32) + + // Now, feed 'b=4', fetch 'plusB=a+2+3+b' + // Imagine this to be the result of actuating the robot to determine + // the error produced by the current state of the neural network. + val, _ = NewTensor(int32(4)) + fetches, err = pr.Run( + map[Output]*Tensor{b: val}, + []Output{plusB}, + nil) + if err != nil { + panic(err) + } + v2 := fetches[0].Value().(int32) + + fmt.Println(v1, v2) + // Output: 3 10 +} diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go index 8ab365c656..492c3b1e8b 100644 --- a/tensorflow/go/util_test.go +++ b/tensorflow/go/util_test.go @@ -46,9 +46,18 @@ func Const(g *Graph, name string, value interface{}) (Output, error) { func Neg(g *Graph, name string, port Output) (Output, error) { op, err := g.AddOperation(OpSpec{ - Type: "Neg", - Name: name, + Type: "Neg", + Name: name, Input: []Input{port}, }) return op.Output(0), err } + +func Add(g *Graph, name string, x, y Output) (Output, error) { + op, err := g.AddOperation(OpSpec{ + Type: "Add", + Name: name, + Input: []Input{x, y}, + }) + return op.Output(0), err +} |