diff options
author | 2017-11-29 22:25:37 -0800 | |
---|---|---|
committer | 2017-11-30 09:18:12 -0800 | |
commit | 5d52b95279be57076a794c2f334c150a26566360 (patch) | |
tree | 3bfdded6fb76b2c7f5ab12b0ca595a722ecd278a /tensorflow/go/graph_test.go | |
parent | 8a4d84969130162ee001fa52bac51e730129399b (diff) |
Adds Operations() method to Graph
There is currently no way to list all of the operations in a graph
from the go api. This patch ads an Operations() method to retrieve the
list using the existing TF_GraphNextOperation c api. The graph_test
was modified to include testing this new method.
Signed-off-by: Vishvananda Ishaya Abrams <vishvananda@gmail.com>
Diffstat (limited to 'tensorflow/go/graph_test.go')
-rw-r--r-- | tensorflow/go/graph_test.go | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go index c3120bc720..b8d65c54f6 100644 --- a/tensorflow/go/graph_test.go +++ b/tensorflow/go/graph_test.go @@ -29,10 +29,26 @@ func hasOperations(g *Graph, ops ...string) error { missing = append(missing, op) } } - if len(missing) == 0 { - return nil + if len(missing) != 0 { + return fmt.Errorf("Graph does not have the operations %v", missing) } - return fmt.Errorf("Graph does not have the operations %v", missing) + + inList := map[string]bool{} + for _, op := range g.Operations() { + inList[op.Name()] = true + } + + for _, op := range ops { + if !inList[op] { + missing = append(missing, op) + } + } + + if len(missing) != 0 { + return fmt.Errorf("Operations %v are missing from graph.Operations()", missing) + } + + return nil } func TestGraphWriteToAndImport(t *testing.T) { |