aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-16 16:28:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-16 16:47:36 -0800
commitc35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch)
tree8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10
parent5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff)
Go: Add PartialRun support.
Change: 147783087
-rw-r--r--tensorflow/go/lib.go10
-rw-r--r--tensorflow/go/session.cpp24
-rw-r--r--tensorflow/go/session.go214
-rw-r--r--tensorflow/go/session_test.go65
-rw-r--r--tensorflow/go/util_test.go13
5 files changed, 281 insertions, 45 deletions
diff --git a/tensorflow/go/lib.go b/tensorflow/go/lib.go
index dcab7a90f8..7f96c7809a 100644
--- a/tensorflow/go/lib.go
+++ b/tensorflow/go/lib.go
@@ -16,4 +16,14 @@ package tensorflow
// #cgo LDFLAGS: -ltensorflow
// #cgo CFLAGS: -I${SRCDIR}/../../
+//
+// // TODO(ashankar): Remove this after TensorFlow 1.1 has been released.
+// // Till then, the TensorFlow C API binary releases do not contain
+// // the TF_DeletePRunHandle symbol. We work around that by
+// // implementing the equivalent in session.cpp
+// extern void tfDeletePRunHandle(const char*);
import "C"
+
+func deletePRunHandle(h *C.char) {
+ C.tfDeletePRunHandle(h)
+}
diff --git a/tensorflow/go/session.cpp b/tensorflow/go/session.cpp
new file mode 100644
index 0000000000..9f6fd1f341
--- /dev/null
+++ b/tensorflow/go/session.cpp
@@ -0,0 +1,24 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// TODO(ashankar): Remove this file when TensorFlow 1.1 is released.
+// See lib.go for details.
+
+extern "C" {
+extern void tfDeletePRunHandle(const char* h);
+}
+
+void tfDeletePRunHandle(const char* h) {
+ delete[] h;
+}
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
index dd629441ef..ef357cb520 100644
--- a/tensorflow/go/session.go
+++ b/tensorflow/go/session.go
@@ -59,14 +59,14 @@ func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
return s, nil
}
-// Run the graph with the associated session starting with the supplied inputs.
-// inputs and outputs may be set to nil. Runs, but does not return Tensors
-// for operations specified in targets.
+// Run the graph with the associated session starting with the supplied feeds
+// to compute the value of the requested fetches. Runs, but does not return
+// Tensors for operations specified in targets.
//
-// On success, returns the Tensor outputs in the same order as supplied in
-// the outputs argument. If outputs is set to nil, the returned Tensor outputs
+// On success, returns the fetched Tensors in the same order as supplied in
+// the fetches argument. If fetches is set to nil, the returned Tensor fetches
// is empty.
-func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Operation) ([]*Tensor, error) {
+func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
s.mu.Lock()
if s.c == nil {
s.mu.Unlock()
@@ -76,56 +76,126 @@ func (s *Session) Run(inputs map[Output]*Tensor, outputs []Output, targets []*Op
s.mu.Unlock()
defer s.wg.Done()
- var inputPorts []C.TF_Output
- var inputValues []*C.TF_Tensor
- if inputs != nil {
- for port, tensor := range inputs {
- inputPorts = append(inputPorts, port.c())
- inputValues = append(inputValues, tensor.c)
- }
+ c := newCRunArgs(feeds, fetches, targets)
+ status := newStatus()
+ C.TF_SessionRun(s.c, nil,
+ ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
+ ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
+ ptrOperation(c.targets), C.int(len(targets)),
+ nil, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
}
+ return c.toGo(), nil
+}
- var outputPorts []C.TF_Output
- for _, port := range outputs {
- outputPorts = append(outputPorts, port.c())
- }
- outputValues := make([]*C.TF_Tensor, len(outputs))
- var cTargets []*C.TF_Operation
- for _, target := range targets {
- cTargets = append(cTargets, target.c)
+// PartialRun enables incremental evaluation of graphs.
+//
+// PartialRun allows the caller to pause the evaluation of a graph, run
+// arbitrary code that depends on the intermediate computation of the graph,
+// and then resume graph execution. The results of the arbitrary code can be
+// fed into the graph when resuming execution. In contrast, Session.Run
+// executes the graph to compute the requested fetches using the provided feeds
+// and discards all intermediate state (e.g., value of intermediate tensors)
+// when it returns.
+//
+// For example, consider a graph for unsupervised training of a neural network
+// model. PartialRun can be used to pause execution after the forward pass of
+// the network, let the caller actuate the output (e.g., play a game, actuate a
+// robot etc.), determine the error/loss and then feed this calculated loss
+// when resuming the backward pass of the graph.
+type PartialRun struct {
+ session *Session
+ handle *C.char
+}
+
+// Run resumes execution of the graph to compute the requested fetches and
+// targets with the provided feeds.
+func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error) {
+ var (
+ c = newCRunArgs(feeds, fetches, targets)
+ status = newStatus()
+ s = pr.session
+ )
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
}
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
- status := newStatus()
- var inputPortsPtr *C.TF_Output
- var inputValuesPtr **C.TF_Tensor
- if len(inputPorts) > 0 {
- inputPortsPtr = &inputPorts[0]
- inputValuesPtr = &inputValues[0]
+ C.TF_SessionPRun(s.c, pr.handle,
+ ptrOutput(c.feeds), ptrTensor(c.feedTensors), C.int(len(feeds)),
+ ptrOutput(c.fetches), ptrTensor(c.fetchTensors), C.int(len(fetches)),
+ ptrOperation(c.targets), C.int(len(targets)),
+ status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
}
+ return c.toGo(), nil
+}
+
+// NewPartialRun sets up the graph for incremental evaluation.
+//
+// All values of feeds, fetches and targets that may be provided to Run calls
+// on the returned PartialRun need to be provided to NewPartialRun.
+//
+// See documentation for the PartialRun type.
+func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error) {
+ var (
+ cfeeds = make([]C.TF_Output, len(feeds))
+ cfetches = make([]C.TF_Output, len(fetches))
+ ctargets = make([]*C.TF_Operation, len(targets))
- var outputPortsPtr *C.TF_Output
- var outputValuesPtr **C.TF_Tensor
- if len(outputPorts) > 0 {
- outputPortsPtr = &outputPorts[0]
- outputValuesPtr = &outputValues[0]
+ pcfeeds *C.TF_Output
+ pcfetches *C.TF_Output
+ pctargets **C.TF_Operation
+
+ status = newStatus()
+ )
+ if len(feeds) > 0 {
+ pcfeeds = &cfeeds[0]
+ for i, o := range feeds {
+ cfeeds[i] = o.c()
+ }
+ }
+ if len(fetches) > 0 {
+ pcfetches = &cfetches[0]
+ for i, o := range fetches {
+ cfetches[i] = o.c()
+ }
+ }
+ if len(targets) > 0 {
+ pctargets = &ctargets[0]
+ for i, o := range targets {
+ ctargets[i] = o.c
+ }
}
- var cTargetsPtr **C.TF_Operation
- if len(cTargets) > 0 {
- cTargetsPtr = &cTargets[0]
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
}
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
- C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c)
+ pr := &PartialRun{session: s}
+ C.TF_SessionPRunSetup(s.c,
+ pcfeeds, C.int(len(feeds)),
+ pcfetches, C.int(len(fetches)),
+ pctargets, C.int(len(targets)),
+ &pr.handle, status.c)
if err := status.Err(); err != nil {
return nil, err
}
-
- tensors := make([]*Tensor, len(outputValues))
- for i, val := range outputValues {
- tensors[i] = newTensorFromC(val)
- }
-
- return tensors, nil
+ runtime.SetFinalizer(pr, func(pr *PartialRun) {
+ deletePRunHandle(pr.handle)
+ })
+ return pr, nil
}
// Close a session. This contacts any other processes associated with this
@@ -187,3 +257,61 @@ func (o *SessionOptions) c() *C.TF_SessionOptions {
C.free(unsafe.Pointer(t))
return opt
}
+
+// cRunArgs translates the arguments to Session.Run and PartialRun.Run into
+// values suitable for C library calls.
+type cRunArgs struct {
+ feeds []C.TF_Output
+ feedTensors []*C.TF_Tensor
+ fetches []C.TF_Output
+ fetchTensors []*C.TF_Tensor
+ targets []*C.TF_Operation
+}
+
+func newCRunArgs(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) *cRunArgs {
+ c := &cRunArgs{
+ fetches: make([]C.TF_Output, len(fetches)),
+ fetchTensors: make([]*C.TF_Tensor, len(fetches)),
+ targets: make([]*C.TF_Operation, len(targets)),
+ }
+ for o, t := range feeds {
+ c.feeds = append(c.feeds, o.c())
+ c.feedTensors = append(c.feedTensors, t.c)
+ }
+ for i, o := range fetches {
+ c.fetches[i] = o.c()
+ }
+ for i, t := range targets {
+ c.targets[i] = t.c
+ }
+ return c
+}
+
+func (c *cRunArgs) toGo() []*Tensor {
+ ret := make([]*Tensor, len(c.fetchTensors))
+ for i, ct := range c.fetchTensors {
+ ret[i] = newTensorFromC(ct)
+ }
+ return ret
+}
+
+func ptrOutput(l []C.TF_Output) *C.TF_Output {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}
+
+func ptrTensor(l []*C.TF_Tensor) **C.TF_Tensor {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}
+
+func ptrOperation(l []*C.TF_Operation) **C.TF_Operation {
+ if len(l) == 0 {
+ return nil
+ }
+ return &l[0]
+}
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 14ecca402b..9afa2be3b4 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -181,3 +181,68 @@ func TestConcurrency(t *testing.T) {
t.Errorf("Close() 2: %v", err)
}
}
+
+func ExamplePartialRun() {
+ var (
+ // Create a graph: a + 2 + 3 + b.
+ //
+ // Skipping error handling for brevity of this example.
+ // The 'op' package can be used to make graph construction code
+ // with error handling more succinct.
+ g = NewGraph()
+ a, _ = Placeholder(g, "a", Int32)
+ b, _ = Placeholder(g, "b", Int32)
+ two, _ = Const(g, "Two", int32(2))
+ three, _ = Const(g, "Three", int32(3))
+
+ plus2, _ = Add(g, "plus2", a, two) // a + 2
+ plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3
+ plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b
+
+ )
+ sess, err := NewSession(g, nil)
+ if err != nil {
+ panic(err)
+ }
+ defer sess.Close()
+
+ // All the feeds, fetches and targets for subsequent PartialRun.Run
+ // calls must be provided at setup.
+ pr, err := sess.NewPartialRun(
+ []Output{a, b},
+ []Output{plus2, plusB},
+ []*Operation{plus3.Op},
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ // Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'.
+ // Imagine this to be the forward pass of unsupervised neural network
+ // training of a robot.
+ val, _ := NewTensor(int32(1))
+ fetches, err := pr.Run(
+ map[Output]*Tensor{a: val},
+ []Output{plus2},
+ nil)
+ if err != nil {
+ panic(err)
+ }
+ v1 := fetches[0].Value().(int32)
+
+ // Now, feed 'b=4', fetch 'plusB=a+2+3+b'
+ // Imagine this to be the result of actuating the robot to determine
+ // the error produced by the current state of the neural network.
+ val, _ = NewTensor(int32(4))
+ fetches, err = pr.Run(
+ map[Output]*Tensor{b: val},
+ []Output{plusB},
+ nil)
+ if err != nil {
+ panic(err)
+ }
+ v2 := fetches[0].Value().(int32)
+
+ fmt.Println(v1, v2)
+ // Output: 3 10
+}
diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go
index 8ab365c656..492c3b1e8b 100644
--- a/tensorflow/go/util_test.go
+++ b/tensorflow/go/util_test.go
@@ -46,9 +46,18 @@ func Const(g *Graph, name string, value interface{}) (Output, error) {
func Neg(g *Graph, name string, port Output) (Output, error) {
op, err := g.AddOperation(OpSpec{
- Type: "Neg",
- Name: name,
+ Type: "Neg",
+ Name: name,
Input: []Input{port},
})
return op.Output(0), err
}
+
+func Add(g *Graph, name string, x, y Output) (Output, error) {
+ op, err := g.AddOperation(OpSpec{
+ Type: "Add",
+ Name: name,
+ Input: []Input{x, y},
+ })
+ return op.Output(0), err
+}