aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Tristan Rice <rice@fn.lc>2018-06-11 13:06:04 -0700
committerGravatar Tristan Rice <rice@fn.lc>2018-06-12 14:50:39 -0700
commit8eba32b6c4b259c39097b8b308532b8419d8c151 (patch)
treec973f7beadd2c82ce4887de88f613c84c20ad069 /tensorflow/go/operation_test.go
parenta4b390bffbcb01d8f57f25c007277d457f752a69 (diff)
tensorflow/go: add operation Input methods + tests
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 40c951ab8c..0672e8ecc7 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -166,6 +166,64 @@ func TestOutputDataTypeAndShape(t *testing.T) {
}
}
+func TestOperationInputs(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ y, err := Placeholder(g, "y", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ add, err := Add(g, "add", x, y)
+ if err != nil {
+ t.Fatal(err)
+ }
+ addOp := add.Op
+
+ if out := addOp.NumInputs(); out != 2 {
+ t.Fatalf("Got %d inputs, wanted 2", out)
+ }
+}
+
+func TestOperationConsumers(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ a, err := Neg(g, "a", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := Neg(g, "b", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ consumers := []*Operation{a.Op, b.Op}
+
+ xConsumers := x.Consumers()
+ if out := len(xConsumers); out != 2 {
+ t.Fatalf("Got %d consumers, wanted 2", out)
+ }
+
+ for i, consumer := range xConsumers {
+ got := consumer.Op.Name()
+ want := consumers[i].Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+
+ got = consumer.Producer().Op.Name()
+ want = x.Op.Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)