diff options
author | 2017-02-09 09:31:08 -0800 | |
---|---|---|
committer | 2017-02-09 09:47:35 -0800 | |
commit | edc4dd62615b7e1ee7906f35099ed2e0a7d0a8ef (patch) | |
tree | fe65a41b23d9aedef5c68be5873b2665eec0d39c /tensorflow/go | |
parent | 78be42e00ec29e507edbf92014709dc1b7ee6a38 (diff) |
Go: Add Output.DataType()
Change: 147043532
Diffstat (limited to 'tensorflow/go')
-rw-r--r-- | tensorflow/go/operation.go | 5 | ||||
-rw-r--r-- | tensorflow/go/operation_test.go | 11 |
2 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index df41c40a2b..9c035e5e18 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -75,6 +75,11 @@ type Output struct { Index int } +// DataType returns the type of elements in the tensor produced by p. +func (p Output) DataType() DataType { + return DataType(C.TF_OperationOutputType(p.c())) +} + // Shape returns the (possibly incomplete) shape of the tensor produced p. func (p Output) Shape() Shape { status := newStatus() diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 32e2989179..a5e36f6683 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -96,19 +96,22 @@ func TestOperationShapeAttribute(t *testing.T) { // If and when the API to get attributes is added, check that here. } -func TestOutputShape(t *testing.T) { +func TestOutputDataTypeAndShape(t *testing.T) { graph := NewGraph() testdata := []struct { Value interface{} Shape []int64 + dtype DataType }{ { // Scalar int64(0), []int64{}, + Int64, }, { // Vector - []int64{1, 2, 3}, + []int32{1, 2, 3}, []int64{3}, + Int32, }, { // Matrix [][]float64{ @@ -116,6 +119,7 @@ func TestOutputShape(t *testing.T) { {4, 5, 6}, }, []int64{2, 3}, + Double, }, } for idx, test := range testdata { @@ -124,6 +128,9 @@ func TestOutputShape(t *testing.T) { if err != nil { t.Fatal(err) } + if got, want := c.DataType(), test.dtype; got != want { + t.Errorf("Got DataType %v, want %v", got, want) + } shape := c.Shape() if got, want := shape.NumDimensions(), len(test.Shape); got != want { t.Fatalf("Got a shape with %d dimensions, want %d", got, want) |