aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph_test.go
diff options
context:
space:
mode:
authorGravatar Vishvananda Ishaya Abrams <vishvananda@gmail.com>2017-11-29 22:25:37 -0800
committerGravatar Vishvananda Ishaya Abrams <vishvananda@gmail.com>2017-11-30 09:18:12 -0800
commit5d52b95279be57076a794c2f334c150a26566360 (patch)
tree3bfdded6fb76b2c7f5ab12b0ca595a722ecd278a /tensorflow/go/graph_test.go
parent8a4d84969130162ee001fa52bac51e730129399b (diff)
Adds Operations() method to Graph
There is currently no way to list all of the operations in a graph from the go api. This patch ads an Operations() method to retrieve the list using the existing TF_GraphNextOperation c api. The graph_test was modified to include testing this new method. Signed-off-by: Vishvananda Ishaya Abrams <vishvananda@gmail.com>
Diffstat (limited to 'tensorflow/go/graph_test.go')
-rw-r--r--tensorflow/go/graph_test.go22
1 files changed, 19 insertions, 3 deletions
diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go
index c3120bc720..b8d65c54f6 100644
--- a/tensorflow/go/graph_test.go
+++ b/tensorflow/go/graph_test.go
@@ -29,10 +29,26 @@ func hasOperations(g *Graph, ops ...string) error {
missing = append(missing, op)
}
}
- if len(missing) == 0 {
- return nil
+ if len(missing) != 0 {
+ return fmt.Errorf("Graph does not have the operations %v", missing)
}
- return fmt.Errorf("Graph does not have the operations %v", missing)
+
+ inList := map[string]bool{}
+ for _, op := range g.Operations() {
+ inList[op.Name()] = true
+ }
+
+ for _, op := range ops {
+ if !inList[op] {
+ missing = append(missing, op)
+ }
+ }
+
+ if len(missing) != 0 {
+ return fmt.Errorf("Operations %v are missing from graph.Operations()", missing)
+ }
+
+ return nil
}
func TestGraphWriteToAndImport(t *testing.T) {