aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-16 16:28:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-16 16:47:36 -0800
commitc35e3b523ae5ee4d557b624e24f1be0ee71edb66 (patch)
tree8b5cc3110bc4e5e77c3fd3274ddd9f5325964a10 /tensorflow/go/session_test.go
parent5cdf2afa5276f4d6b97ca7c6812661994d28957b (diff)
Go: Add PartialRun support.
Change: 147783087
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r--tensorflow/go/session_test.go65
1 files changed, 65 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 14ecca402b..9afa2be3b4 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -181,3 +181,68 @@ func TestConcurrency(t *testing.T) {
t.Errorf("Close() 2: %v", err)
}
}
+
+func ExamplePartialRun() {
+ var (
+ // Create a graph: a + 2 + 3 + b.
+ //
+ // Skipping error handling for brevity of this example.
+ // The 'op' package can be used to make graph construction code
+ // with error handling more succinct.
+ g = NewGraph()
+ a, _ = Placeholder(g, "a", Int32)
+ b, _ = Placeholder(g, "b", Int32)
+ two, _ = Const(g, "Two", int32(2))
+ three, _ = Const(g, "Three", int32(3))
+
+ plus2, _ = Add(g, "plus2", a, two) // a + 2
+ plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3
+ plusB, _ = Add(g, "plusB", plus3, b) // ((a + 2) + 3) + b
+
+ )
+ sess, err := NewSession(g, nil)
+ if err != nil {
+ panic(err)
+ }
+ defer sess.Close()
+
+ // All the feeds, fetches and targets for subsequent PartialRun.Run
+ // calls must be provided at setup.
+ pr, err := sess.NewPartialRun(
+ []Output{a, b},
+ []Output{plus2, plusB},
+ []*Operation{plus3.Op},
+ )
+ if err != nil {
+ panic(err)
+ }
+
+ // Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'.
+ // Imagine this to be the forward pass of unsupervised neural network
+ // training of a robot.
+ val, _ := NewTensor(int32(1))
+ fetches, err := pr.Run(
+ map[Output]*Tensor{a: val},
+ []Output{plus2},
+ nil)
+ if err != nil {
+ panic(err)
+ }
+ v1 := fetches[0].Value().(int32)
+
+ // Now, feed 'b=4', fetch 'plusB=a+2+3+b'
+ // Imagine this to be the result of actuating the robot to determine
+ // the error produced by the current state of the neural network.
+ val, _ = NewTensor(int32(4))
+ fetches, err = pr.Run(
+ map[Output]*Tensor{b: val},
+ []Output{plusB},
+ nil)
+ if err != nil {
+ panic(err)
+ }
+ v2 := fetches[0].Value().(int32)
+
+ fmt.Println(v1, v2)
+ // Output: 3 10
+}