aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/go/graph.go14
-rw-r--r--tensorflow/go/graph_test.go22
2 files changed, 33 insertions, 3 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index fb0af87acc..fc087d9d99 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -133,6 +133,20 @@ func (g *Graph) Operation(name string) *Operation {
return &Operation{cop, g}
}
+// Operations returns a list of all operations in the graph
+func (g *Graph) Operations() []Operation {
+ var pos C.size_t = 0
+ ops := []Operation{}
+ for {
+ cop := C.TF_GraphNextOperation(g.c, &pos)
+ if cop == nil {
+ break
+ }
+ ops = append(ops, Operation{cop, g})
+ }
+ return ops
+}
+
// OpSpec is the specification of an Operation to be added to a Graph
// (using Graph.AddOperation).
type OpSpec struct {
diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go
index c3120bc720..b8d65c54f6 100644
--- a/tensorflow/go/graph_test.go
+++ b/tensorflow/go/graph_test.go
@@ -29,10 +29,26 @@ func hasOperations(g *Graph, ops ...string) error {
missing = append(missing, op)
}
}
- if len(missing) == 0 {
- return nil
+ if len(missing) != 0 {
+ return fmt.Errorf("Graph does not have the operations %v", missing)
}
- return fmt.Errorf("Graph does not have the operations %v", missing)
+
+ inList := map[string]bool{}
+ for _, op := range g.Operations() {
+ inList[op.Name()] = true
+ }
+
+ for _, op := range ops {
+ if !inList[op] {
+ missing = append(missing, op)
+ }
+ }
+
+ if len(missing) != 0 {
+ return fmt.Errorf("Operations %v are missing from graph.Operations()", missing)
+ }
+
+ return nil
}
func TestGraphWriteToAndImport(t *testing.T) {