diff options
author | Asim Shankar <ashankar@google.com> | 2017-02-09 09:31:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-09 09:47:35 -0800 |
commit | edc4dd62615b7e1ee7906f35099ed2e0a7d0a8ef (patch) | |
tree | fe65a41b23d9aedef5c68be5873b2665eec0d39c /tensorflow/go/operation_test.go | |
parent | 78be42e00ec29e507edbf92014709dc1b7ee6a38 (diff) |
Go: Add Output.DataType()
Change: 147043532
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 32e2989179..a5e36f6683 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -96,19 +96,22 @@ func TestOperationShapeAttribute(t *testing.T) { // If and when the API to get attributes is added, check that here. } -func TestOutputShape(t *testing.T) { +func TestOutputDataTypeAndShape(t *testing.T) { graph := NewGraph() testdata := []struct { Value interface{} Shape []int64 + dtype DataType }{ { // Scalar int64(0), []int64{}, + Int64, }, { // Vector - []int64{1, 2, 3}, + []int32{1, 2, 3}, []int64{3}, + Int32, }, { // Matrix [][]float64{ @@ -116,6 +119,7 @@ func TestOutputShape(t *testing.T) { {4, 5, 6}, }, []int64{2, 3}, + Double, }, } for idx, test := range testdata { @@ -124,6 +128,9 @@ func TestOutputShape(t *testing.T) { if err != nil { t.Fatal(err) } + if got, want := c.DataType(), test.dtype; got != want { + t.Errorf("Got DataType %v, want %v", got, want) + } shape := c.Shape() if got, want := shape.NumDimensions(), len(test.Shape); got != want { t.Fatalf("Got a shape with %d dimensions, want %d", got, want) |