diff options
-rw-r--r-- | tensorflow/go/graph.go | 14 | ||||
-rw-r--r-- | tensorflow/go/graph_test.go | 22 |
2 files changed, 33 insertions, 3 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index fb0af87acc..fc087d9d99 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -133,6 +133,20 @@ func (g *Graph) Operation(name string) *Operation { return &Operation{cop, g} } +// Operations returns a list of all operations in the graph +func (g *Graph) Operations() []Operation { + var pos C.size_t = 0 + ops := []Operation{} + for { + cop := C.TF_GraphNextOperation(g.c, &pos) + if cop == nil { + break + } + ops = append(ops, Operation{cop, g}) + } + return ops +} + // OpSpec is the specification of an Operation to be added to a Graph // (using Graph.AddOperation). type OpSpec struct { diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go index c3120bc720..b8d65c54f6 100644 --- a/tensorflow/go/graph_test.go +++ b/tensorflow/go/graph_test.go @@ -29,10 +29,26 @@ func hasOperations(g *Graph, ops ...string) error { missing = append(missing, op) } } - if len(missing) == 0 { - return nil + if len(missing) != 0 { + return fmt.Errorf("Graph does not have the operations %v", missing) } - return fmt.Errorf("Graph does not have the operations %v", missing) + + inList := map[string]bool{} + for _, op := range g.Operations() { + inList[op.Name()] = true + } + + for _, op := range ops { + if !inList[op] { + missing = append(missing, op) + } + } + + if len(missing) != 0 { + return fmt.Errorf("Operations %v are missing from graph.Operations()", missing) + } + + return nil } func TestGraphWriteToAndImport(t *testing.T) { |