diff options
author | 2016-09-28 10:37:31 -0800 | |
---|---|---|
committer | 2016-09-28 11:48:45 -0700 | |
commit | 7cacfdf03872c36cfc3f43e75e8c342234160e3a (patch) | |
tree | 0856b1decb0a1f36a68450c9a5d9b825de999c06 /tensorflow/go/operation_test.go | |
parent | b410f52f6616e281d8274317938a258605ad993a (diff) |
go: Add an example.
Add an example (that will appear in Go doc) for a real use of the Go TensorFlow
APIs - using a pre-defined image recognition model for inference.
While at it a couple of minor tweaks:
- NewSession now accepts 'nil' for options as the documentation says it does
- Convenience accessors for the Outputs of an Operation
- Ability to extract (possibly partial) shapes from an Output
Another step towards #10
Change: 134560938
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 7cba0a5c27..35a5d95357 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -15,6 +15,7 @@ package tensorflow import ( + "fmt" "runtime" "runtime/debug" "testing" @@ -50,6 +51,68 @@ func TestOperationLifetime(t *testing.T) { } } +func TestOutputShape(t *testing.T) { + graph := NewGraph() + testdata := []struct { + Value interface{} + Shape []int64 + }{ + { // Scalar + int64(0), + []int64{}, + }, + { // Vector + []int64{1, 2, 3}, + []int64{3}, + }, + { // Matrix + [][]float64{ + {1, 2, 3}, + {4, 5, 6}, + }, + []int64{2, 3}, + }, + } + for idx, test := range testdata { + tensor, err := NewTensor(test.Value) + if err != nil { + t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err) + continue + } + c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor) + if err != nil { + t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err) + continue + } + shape, err := c.Shape() + if err != nil { + t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err) + continue + } + if got, want := len(shape), len(test.Shape); got != want { + t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want) + continue + } + for i := 0; i < len(shape); i++ { + if got, want := shape[i], test.Shape[i]; got != want { + t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape)) + } + } + } + // Unknown number of dimensions + dummyTensor, err := NewTensor(float64(0)) + if err != nil { + t.Fatal(err) + } + placeholder, err := Placeholder(graph, "placeholder", dummyTensor.DataType()) + if err != nil { + t.Fatal(err) + } + if shape, err := placeholder.Shape(); err == nil { + t.Errorf("Got shape %v, wanted error", shape) + } +} + func forceGC() { var mem runtime.MemStats runtime.ReadMemStats(&mem) |