1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tensorflow
import (
"reflect"
"testing"
)
// Placeholder builds an operation of type "Placeholder" called name in g,
// with its "dtype" attribute set to dt.
//
// On success it returns the Port made from the built operation and 0
// (presumably the operation's first output; confirm against Port's
// definition). If the build fails, it returns the zero Port and the error.
func Placeholder(g *Graph, name string, dt DataType) (Port, error) {
	var result Port

	builder := newOpBuilder(g, "Placeholder", name)
	builder.SetAttrType("dtype", dt)

	built, buildErr := builder.Build()
	if buildErr != nil {
		return result, buildErr
	}
	result = Port{built, 0}
	return result, nil
}
// Neg builds an operation of type "Neg" called name in g, taking port as its
// only input.
//
// On success it returns the Port made from the built operation and 0
// (presumably the operation's first output; confirm against Port's
// definition). If the build fails, it returns the zero Port and the error.
func Neg(g *Graph, name string, port Port) (Port, error) {
	var result Port

	builder := newOpBuilder(g, "Neg", name)
	builder.AddInput(port)

	built, buildErr := builder.Build()
	if buildErr != nil {
		return result, buildErr
	}
	result = Port{built, 0}
	return result, nil
}
// createTestGraph returns a fresh graph holding a Placeholder named "p1" of
// type dt whose port is fed into a Neg named "neg1".
//
// The returned ports are, in order, the placeholder's port and the Neg's
// port. Any construction error aborts the calling test through t.Fatalf.
func createTestGraph(t *testing.T, dt DataType) (*Graph, Port, Port) {
	graph := NewGraph()

	placeholder, err := Placeholder(graph, "p1", dt)
	if err != nil {
		t.Fatalf("Placeholder() for %v: %v", dt, err)
	}

	negated, err := Neg(graph, "neg1", placeholder)
	if err != nil {
		t.Fatalf("Neg() for %v: %v", dt, err)
	}

	return graph, placeholder, negated
}
// TestSessionRunNeg feeds tensors of several element types and shapes through
// the Placeholder -> Neg graph built by createTestGraph and checks that
// Session.Run returns exactly one tensor whose value is the negated input.
func TestSessionRunNeg(t *testing.T) {
	tests := []struct {
		input    interface{}
		expected interface{}
	}{
		{int64(1), int64(-1)},
		{[]float64{-1, -2, 3}, []float64{1, 2, -3}},
		{[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}},
	}
	for _, test := range tests {
		// Each case runs inside its own function so the deferred Close fires
		// at the end of the iteration on every exit path. Previously the
		// session was leaked when the output count was wrong (the continue
		// skipped Close) or when Run failed.
		func() {
			t1, err := NewTensor(test.input)
			if err != nil {
				t.Fatalf("NewTensor(%v): %v", test.input, err)
			}
			graph, inp, out := createTestGraph(t, t1.DataType())
			s, err := NewSession(graph, &SessionOptions{})
			if err != nil {
				t.Fatalf("NewSession() for %v: %v", test.input, err)
			}
			// t.Fatalf exits via runtime.Goexit, which still runs deferred
			// calls, so the session is closed even on a fatal failure below.
			defer func() {
				if err := s.Close(); err != nil {
					t.Errorf("Close() for %v: %v", test.input, err)
				}
			}()

			output, err := s.Run(map[Port]*Tensor{inp: t1}, []Port{out}, []*Operation{out.Op})
			if err != nil {
				t.Fatalf("Run() for %v: %v", test.input, err)
			}
			if len(output) != 1 {
				t.Errorf("%v: got %d outputs, want 1", test.input, len(output))
				return
			}
			val := output[0].Value()
			if !reflect.DeepEqual(test.expected, val) {
				t.Errorf("%v: got %v, want %v", test.input, val, test.expected)
			}
		}()
	}
}
// TestConcurrency starts 100 goroutines that each call Run on one shared
// session, then closes that session twice from the test goroutine and expects
// both Close calls to return a nil error.
func TestConcurrency(t *testing.T) {
	const concurrentRuns = 100

	tensor, err := NewTensor(int64(1))
	if err != nil {
		t.Fatalf("NewTensor(): %v", err)
	}

	graph, inp, out := createTestGraph(t, tensor.DataType())
	s, err := NewSession(graph, &SessionOptions{})
	if err != nil {
		t.Fatalf("NewSession(): %v", err)
	}

	for n := 0; n < concurrentRuns; n++ {
		feeds := map[Port]*Tensor{inp: tensor}
		fetches := []Port{out}
		targets := []*Operation{out.Op}
		// Session may close before Run() starts, so we don't check the error.
		go s.Run(feeds, fetches, targets)
	}

	if err := s.Close(); err != nil {
		t.Errorf("Close() 1: %v", err)
	}
	if err := s.Close(); err != nil {
		t.Errorf("Close() 2: %v", err)
	}
}
|