diff options
author | Jonathan Hseu <jhseu@google.com> | 2016-08-23 09:01:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-23 10:04:53 -0700 |
commit | 783c52edeb3c676937dbb97ed0d40958015050d6 (patch) | |
tree | 80c74954f68dad26a6e76a1c0edcb979d4d1804c /tensorflow/go/session_test.go | |
parent | 096069687c52e16eaa18c1db6e7bbf2737639257 (diff) |
Initial version of the Go API. The API is subject to change.
Remaining work to do:
- Generated ops.
- Generated protocol buffers.
- A few calls requiring protocol buffers aren't in this change.
Change: 131066649
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r-- | tensorflow/go/session_test.go | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go new file mode 100644 index 0000000000..78f6bccfd6 --- /dev/null +++ b/tensorflow/go/session_test.go @@ -0,0 +1,114 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "reflect" + "testing" +) + +func Placeholder(g *Graph, name string, dt DataType) (Port, error) { + b := newOpBuilder(g, "Placeholder", name) + b.SetAttrType("dtype", dt) + op, err := b.Build() + if err != nil { + return Port{}, err + } + return Port{op, 0}, nil +} + +func Neg(g *Graph, name string, port Port) (Port, error) { + b := newOpBuilder(g, "Neg", name) + b.AddInput(port) + op, err := b.Build() + if err != nil { + return Port{}, err + } + return Port{op, 0}, nil +} + +func createTestGraph(t *testing.T, dt DataType) (*Graph, Port, Port) { + g := NewGraph() + inp, err := Placeholder(g, "p1", dt) + if err != nil { + t.Fatalf("Placeholder() for %v: %v", dt, err) + } + out, err := Neg(g, "neg1", inp) + if err != nil { + t.Fatalf("Neg() for %v: %v", dt, err) + } + return g, inp, out +} + +func TestSessionRunNeg(t *testing.T) { + var tests = []struct { + input interface{} + expected interface{} + }{ + {int64(1), int64(-1)}, + {[]float64{-1, -2, 3}, []float64{1, 2, -3}}, + {[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}}, + } + + for _, test := range tests { + t1, err := NewTensor(test.input) + if err != nil { + t.Fatalf("NewTensor(%v): %v", test.input, err) + } + graph, inp, out := createTestGraph(t, t1.DataType()) + s, err := NewSession(graph, &SessionOptions{}) + if err != nil { + t.Fatalf("NewSession() for %v: %v", test.input, err) + } + output, err := s.Run(map[Port]*Tensor{inp: t1}, []Port{out}, []*Operation{out.Op}) + if err != nil { + t.Fatalf("Run() for %v: %v", test.input, err) + } + if len(output) != 1 { + t.Errorf("%v: got %d outputs, want 1", test.input, len(output)) + continue + } + val := output[0].Value() + if !reflect.DeepEqual(test.expected, val) { + t.Errorf("got %v, want %v", val, test.expected) + } + if err := s.Close(); err != nil { + t.Errorf("Close(): %v", err) + } + } +} + +func TestConcurrency(t *testing.T) { + tensor, err := NewTensor(int64(1)) + if err != nil { + t.Fatalf("NewTensor(): %v", err) + } + + graph, inp, out := createTestGraph(t, tensor.DataType()) + s, err := NewSession(graph, &SessionOptions{}) + if err != nil { + t.Fatalf("NewSession(): %v", err) + } + for i := 0; i < 100; i++ { + // Session may close before Run() starts, so we don't check the error. + go s.Run(map[Port]*Tensor{inp: tensor}, []Port{out}, []*Operation{out.Op}) + } + if err = s.Close(); err != nil { + t.Errorf("Close() 1: %v", err) + } + if err = s.Close(); err != nil { + t.Errorf("Close() 2: %v", err) + } +} |