diff options
author | Tristan Rice <rice@fn.lc> | 2018-06-11 13:06:04 -0700 |
---|---|---|
committer | Tristan Rice <rice@fn.lc> | 2018-06-12 14:50:39 -0700 |
commit | 8eba32b6c4b259c39097b8b308532b8419d8c151 (patch) | |
tree | c973f7beadd2c82ce4887de88f613c84c20ad069 /tensorflow/go/operation_test.go | |
parent | a4b390bffbcb01d8f57f25c007277d457f752a69 (diff) |
tensorflow/go: add operation Input methods + tests
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 40c951ab8c..0672e8ecc7 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -166,6 +166,64 @@ func TestOutputDataTypeAndShape(t *testing.T) { } } +func TestOperationInputs(t *testing.T) { + g := NewGraph() + x, err := Placeholder(g, "x", Float) + if err != nil { + t.Fatal(err) + } + y, err := Placeholder(g, "y", Float) + if err != nil { + t.Fatal(err) + } + add, err := Add(g, "add", x, y) + if err != nil { + t.Fatal(err) + } + addOp := add.Op + + if out := addOp.NumInputs(); out != 2 { + t.Fatalf("Got %d inputs, wanted 2", out) + } +} + +func TestOperationConsumers(t *testing.T) { + g := NewGraph() + x, err := Placeholder(g, "x", Float) + if err != nil { + t.Fatal(err) + } + a, err := Neg(g, "a", x) + if err != nil { + t.Fatal(err) + } + b, err := Neg(g, "b", x) + if err != nil { + t.Fatal(err) + } + + consumers := []*Operation{a.Op, b.Op} + + xConsumers := x.Consumers() + if out := len(xConsumers); out != 2 { + t.Fatalf("Got %d consumers, wanted 2", out) + } + + for i, consumer := range xConsumers { + got := consumer.Op.Name() + want := consumers[i].Name() + if got != want { + t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) + } + + got = consumer.Producer().Op.Name() + want = x.Op.Name() + if got != want { + t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) + } + } +} + func forceGC() { var mem runtime.MemStats runtime.ReadMemStats(&mem) |