diff options
author | Asim Shankar <ashankar@google.com> | 2016-10-07 09:56:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-07 11:04:39 -0700 |
commit | b5b1671163c90cf786fe0c65ca6754a85d017847 (patch) | |
tree | 313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/operation_test.go | |
parent | 318948b28076540841f0c079f12c8f4fdd3d15f1 (diff) |
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph.
I also intend to use this API in a code generator that will generate
Go sources files containing functions for each OpDef (and all
this generated code will be in a separate package).
While at it, also changed some tests to use the "sub-tests"
feature in Go 1.7 (https://blog.golang.org/subtests)
Another step in the journey of #10
Change: 135493412
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 40 |
1 files changed, 17 insertions, 23 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 35a5d95357..1f04528e1b 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -74,30 +74,24 @@ func TestOutputShape(t *testing.T) { }, } for idx, test := range testdata { - tensor, err := NewTensor(test.Value) - if err != nil { - t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err) - continue - } - c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor) - if err != nil { - t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err) - continue - } - shape, err := c.Shape() - if err != nil { - t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err) - continue - } - if got, want := len(shape), len(test.Shape); got != want { - t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want) - continue - } - for i := 0; i < len(shape); i++ { - if got, want := shape[i], test.Shape[i]; got != want { - t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape)) + t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) { + c, err := Const(graph, fmt.Sprintf("const%d", idx), test.Value) + if err != nil { + t.Fatal(err) } - } + shape, err := c.Shape() + if err != nil { + t.Fatal(err) + } + if got, want := len(shape), len(test.Shape); got != want { + t.Fatalf("Got a shape with %d dimensions, want %d", got, want) + } + for i := 0; i < len(shape); i++ { + if got, want := shape[i], test.Shape[i]; got != want { + t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(shape)) + } + } + }) } // Unknown number of dimensions dummyTensor, err := NewTensor(float64(0)) |