aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-14 12:33:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 13:47:09 -0700
commit47318618b24d493877668750f84d114179120ca5 (patch)
treea9735bdfa7831f43100882e4927db85798b7093a /tensorflow/go/operation_test.go
parent33970acb0a94fbf3f3ef0f56733fd6d6a4af4a01 (diff)
go: Add Operation.OutputListSize
This will be needed for generating the function wrappers for ops that produce a list of tensors as output. Another step towards #10 Change: 136191993
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 1f04528e1b..8080515ee9 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -51,6 +51,36 @@ func TestOperationLifetime(t *testing.T) {
}
}
+func TestOperationOutputListSize(t *testing.T) {
+ graph := NewGraph()
+ c1, err := Const(graph, "c1", int64(1))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c2, err := Const(graph, "c2", [][]int64{{1, 2}, {3, 4}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ // The ShapeN op takes a list of tensors as input and a list as output.
+ op, err := graph.AddOperation(OpSpec{
+ Type: "ShapeN",
+ Input: []Input{OutputList{c1, c2}},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ n, err := op.OutputListSize("output")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := n, 2; got != want {
+ t.Errorf("Got %d, want %d", got, want)
+ }
+ if got, want := op.NumOutputs(), 2; got != want {
+ t.Errorf("Got %d, want %d", got, want)
+ }
+}
+
func TestOutputShape(t *testing.T) {
graph := NewGraph()
testdata := []struct {