aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-09-28 10:37:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-28 11:48:45 -0700
commit7cacfdf03872c36cfc3f43e75e8c342234160e3a (patch)
tree0856b1decb0a1f36a68450c9a5d9b825de999c06 /tensorflow/go/operation_test.go
parentb410f52f6616e281d8274317938a258605ad993a (diff)
go: Add an example.
Add an example (that will appear in Go doc) for a real use of the Go TensorFlow APIs - using a pre-defined image recognition model for inference. While at it a couple of minor tweaks: - NewSession now accepts 'nil' for options as the documentation says it does - Convenience accessors for the Outputs of an Operation - Ability to extract (possibly partial) shapes from an Output Another step towards #10 Change: 134560938
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 7cba0a5c27..35a5d95357 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -15,6 +15,7 @@
package tensorflow
import (
+ "fmt"
"runtime"
"runtime/debug"
"testing"
@@ -50,6 +51,68 @@ func TestOperationLifetime(t *testing.T) {
}
}
+func TestOutputShape(t *testing.T) {
+ graph := NewGraph()
+ testdata := []struct {
+ Value interface{}
+ Shape []int64
+ }{
+ { // Scalar
+ int64(0),
+ []int64{},
+ },
+ { // Vector
+ []int64{1, 2, 3},
+ []int64{3},
+ },
+ { // Matrix
+ [][]float64{
+ {1, 2, 3},
+ {4, 5, 6},
+ },
+ []int64{2, 3},
+ },
+ }
+ for idx, test := range testdata {
+ tensor, err := NewTensor(test.Value)
+ if err != nil {
+ t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err)
+ continue
+ }
+ c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor)
+ if err != nil {
+ t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err)
+ continue
+ }
+ shape, err := c.Shape()
+ if err != nil {
+ t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err)
+ continue
+ }
+ if got, want := len(shape), len(test.Shape); got != want {
+ t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want)
+ continue
+ }
+ for i := 0; i < len(shape); i++ {
+ if got, want := shape[i], test.Shape[i]; got != want {
+ t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape))
+ }
+ }
+ }
+ // Unknown number of dimensions
+ dummyTensor, err := NewTensor(float64(0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ placeholder, err := Placeholder(graph, "placeholder", dummyTensor.DataType())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if shape, err := placeholder.Shape(); err == nil {
+ t.Errorf("Got shape %v, wanted error", shape)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)