aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-07 09:56:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 11:04:39 -0700
commitb5b1671163c90cf786fe0c65ca6754a85d017847 (patch)
tree313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/operation_test.go
parent318948b28076540841f0c079f12c8f4fdd3d15f1 (diff)
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph. I also intend to use this API in a code generator that will generate Go sources files containing functions for each OpDef (and all this generated code will be in a separate package). While at it, also changed some tests to use the "sub-tests" feature in Go 1.7 (https://blog.golang.org/subtests) Another step in the journey of #10 Change: 135493412
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go40
1 files changed, 17 insertions, 23 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 35a5d95357..1f04528e1b 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -74,30 +74,24 @@ func TestOutputShape(t *testing.T) {
},
}
for idx, test := range testdata {
- tensor, err := NewTensor(test.Value)
- if err != nil {
- t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err)
- continue
- }
- c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor)
- if err != nil {
- t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err)
- continue
- }
- shape, err := c.Shape()
- if err != nil {
- t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err)
- continue
- }
- if got, want := len(shape), len(test.Shape); got != want {
- t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want)
- continue
- }
- for i := 0; i < len(shape); i++ {
- if got, want := shape[i], test.Shape[i]; got != want {
- t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape))
+ t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) {
+ c, err := Const(graph, fmt.Sprintf("const%d", idx), test.Value)
+ if err != nil {
+ t.Fatal(err)
}
- }
+ shape, err := c.Shape()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(shape), len(test.Shape); got != want {
+ t.Fatalf("Got a shape with %d dimensions, want %d", got, want)
+ }
+ for i := 0; i < len(shape); i++ {
+ if got, want := shape[i], test.Shape[i]; got != want {
+ t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(shape))
+ }
+ }
+ })
}
// Unknown number of dimensions
dummyTensor, err := NewTensor(float64(0))