aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-02-09 09:31:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-09 09:47:35 -0800
commitedc4dd62615b7e1ee7906f35099ed2e0a7d0a8ef (patch)
treefe65a41b23d9aedef5c68be5873b2665eec0d39c /tensorflow/go/operation_test.go
parent78be42e00ec29e507edbf92014709dc1b7ee6a38 (diff)
Go: Add Output.DataType()
Change: 147043532
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 32e2989179..a5e36f6683 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -96,19 +96,22 @@ func TestOperationShapeAttribute(t *testing.T) {
// If and when the API to get attributes is added, check that here.
}
-func TestOutputShape(t *testing.T) {
+func TestOutputDataTypeAndShape(t *testing.T) {
graph := NewGraph()
testdata := []struct {
Value interface{}
Shape []int64
+ dtype DataType
}{
{ // Scalar
int64(0),
[]int64{},
+ Int64,
},
{ // Vector
- []int64{1, 2, 3},
+ []int32{1, 2, 3},
[]int64{3},
+ Int32,
},
{ // Matrix
[][]float64{
@@ -116,6 +119,7 @@ func TestOutputShape(t *testing.T) {
{4, 5, 6},
},
[]int64{2, 3},
+ Double,
},
}
for idx, test := range testdata {
@@ -124,6 +128,9 @@ func TestOutputShape(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ if got, want := c.DataType(), test.dtype; got != want {
+ t.Errorf("Got DataType %v, want %v", got, want)
+ }
shape := c.Shape()
if got, want := shape.NumDimensions(), len(test.Shape); got != want {
t.Fatalf("Got a shape with %d dimensions, want %d", got, want)