diff options
author | Asim Shankar <ashankar@google.com> | 2017-02-16 16:28:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-16 16:47:36 -0800 |
commit | c35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch) | |
tree | 8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10 /tensorflow/go/session_test.go | |
parent | 5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff) |
Go: Add PartialRun support.
Change: 147783087
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r-- | tensorflow/go/session_test.go | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 14ecca402b..9afa2be3b4 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -181,3 +181,68 @@ func TestConcurrency(t *testing.T) { t.Errorf("Close() 2: %v", err) } } + +func ExamplePartialRun() { + var ( + // Create a graph: a + 2 + 3 + b. + // + // Skipping error handling for brevity of this example. + // The 'op' package can be used to make graph construction code + // with error handling more succinct. + g = NewGraph() + a, _ = Placeholder(g, "a", Int32) + b, _ = Placeholder(g, "b", Int32) + two, _ = Const(g, "Two", int32(2)) + three, _ = Const(g, "Three", int32(3)) + + plus2, _ = Add(g, "plus2", a, two) // a + 2 + plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3 + plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b + + ) + sess, err := NewSession(g, nil) + if err != nil { + panic(err) + } + defer sess.Close() + + // All the feeds, fetches and targets for subsequent PartialRun.Run + // calls must be provided at setup. + pr, err := sess.NewPartialRun( + []Output{a, b}, + []Output{plus2, plusB}, + []*Operation{plus3.Op}, + ) + if err != nil { + panic(err) + } + + // Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'. + // Imagine this to be the forward pass of unsupervised neural network + // training of a robot. + val, _ := NewTensor(int32(1)) + fetches, err := pr.Run( + map[Output]*Tensor{a: val}, + []Output{plus2}, + nil) + if err != nil { + panic(err) + } + v1 := fetches[0].Value().(int32) + + // Now, feed 'b=4', fetch 'plusB=a+2+3+b' + // Imagine this to be the result of actuating the robot to determine + // the error produced by the current state of the neural network. + val, _ = NewTensor(int32(4)) + fetches, err = pr.Run( + map[Output]*Tensor{b: val}, + []Output{plusB}, + nil) + if err != nil { + panic(err) + } + v2 := fetches[0].Value().(int32) + + fmt.Println(v1, v2) + // Output: 3 10 +} |