aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session_test.go
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-08-23 09:01:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 10:04:53 -0700
commit783c52edeb3c676937dbb97ed0d40958015050d6 (patch)
tree80c74954f68dad26a6e76a1c0edcb979d4d1804c /tensorflow/go/session_test.go
parent096069687c52e16eaa18c1db6e7bbf2737639257 (diff)
Initial version of the Go API. The API is subject to change.
Remaining work to do: - Generated ops. - Generated protocol buffers. - A few calls requiring protocol buffers aren't in this change. Change: 131066649
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r--tensorflow/go/session_test.go114
1 files changed, 114 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
new file mode 100644
index 0000000000..78f6bccfd6
--- /dev/null
+++ b/tensorflow/go/session_test.go
@@ -0,0 +1,114 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "reflect"
+ "testing"
+)
+
+func Placeholder(g *Graph, name string, dt DataType) (Port, error) {
+ b := newOpBuilder(g, "Placeholder", name)
+ b.SetAttrType("dtype", dt)
+ op, err := b.Build()
+ if err != nil {
+ return Port{}, err
+ }
+ return Port{op, 0}, nil
+}
+
+func Neg(g *Graph, name string, port Port) (Port, error) {
+ b := newOpBuilder(g, "Neg", name)
+ b.AddInput(port)
+ op, err := b.Build()
+ if err != nil {
+ return Port{}, err
+ }
+ return Port{op, 0}, nil
+}
+
+func createTestGraph(t *testing.T, dt DataType) (*Graph, Port, Port) {
+ g := NewGraph()
+ inp, err := Placeholder(g, "p1", dt)
+ if err != nil {
+ t.Fatalf("Placeholder() for %v: %v", dt, err)
+ }
+ out, err := Neg(g, "neg1", inp)
+ if err != nil {
+ t.Fatalf("Neg() for %v: %v", dt, err)
+ }
+ return g, inp, out
+}
+
+func TestSessionRunNeg(t *testing.T) {
+ var tests = []struct {
+ input interface{}
+ expected interface{}
+ }{
+ {int64(1), int64(-1)},
+ {[]float64{-1, -2, 3}, []float64{1, 2, -3}},
+ {[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}},
+ }
+
+ for _, test := range tests {
+ t1, err := NewTensor(test.input)
+ if err != nil {
+ t.Fatalf("NewTensor(%v): %v", test.input, err)
+ }
+ graph, inp, out := createTestGraph(t, t1.DataType())
+ s, err := NewSession(graph, &SessionOptions{})
+ if err != nil {
+ t.Fatalf("NewSession() for %v: %v", test.input, err)
+ }
+ output, err := s.Run(map[Port]*Tensor{inp: t1}, []Port{out}, []*Operation{out.Op})
+ if err != nil {
+ t.Fatalf("Run() for %v: %v", test.input, err)
+ }
+ if len(output) != 1 {
+ t.Errorf("%v: got %d outputs, want 1", test.input, len(output))
+ continue
+ }
+ val := output[0].Value()
+ if !reflect.DeepEqual(test.expected, val) {
+ t.Errorf("got %v, want %v", val, test.expected)
+ }
+ if err := s.Close(); err != nil {
+ t.Errorf("Close(): %v", err)
+ }
+ }
+}
+
+func TestConcurrency(t *testing.T) {
+ tensor, err := NewTensor(int64(1))
+ if err != nil {
+ t.Fatalf("NewTensor(): %v", err)
+ }
+
+ graph, inp, out := createTestGraph(t, tensor.DataType())
+ s, err := NewSession(graph, &SessionOptions{})
+ if err != nil {
+ t.Fatalf("NewSession(): %v", err)
+ }
+ for i := 0; i < 100; i++ {
+ // Session may close before Run() starts, so we don't check the error.
+ go s.Run(map[Port]*Tensor{inp: tensor}, []Port{out}, []*Operation{out.Op})
+ }
+ if err = s.Close(); err != nil {
+ t.Errorf("Close() 1: %v", err)
+ }
+ if err = s.Close(); err != nil {
+ t.Errorf("Close() 2: %v", err)
+ }
+}