aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-09-23 10:25:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-23 11:33:41 -0700
commit7bffed6ba12a4d463a0dfb1978e2108343c63eaf (patch)
tree292e668023177bb637d1fae08f1905b843f3a32e /tensorflow/go/graph_test.go
parentf982313d3c32cf2c7018eae9ec49759ae111d407 (diff)
go: Ability to import a pre-defined Graph.
With this change, it should be possible to execute a pre-defined Graph created by any means (like a training session in a Python program) in Go. One more step towards #10 Change: 134096795
Diffstat (limited to 'tensorflow/go/graph_test.go')
-rw-r--r--tensorflow/go/graph_test.go64
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go
new file mode 100644
index 0000000000..43f80ff4eb
--- /dev/null
+++ b/tensorflow/go/graph_test.go
@@ -0,0 +1,64 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+)
+
+func hasOperations(g *Graph, ops ...string) error {
+ var missing []string
+ for _, op := range ops {
+ if g.Operation(op) == nil {
+ missing = append(missing, op)
+ }
+ }
+ if len(missing) == 0 {
+ return nil
+ }
+ return fmt.Errorf("Graph does not have the operations %v", missing)
+}
+
+func TestGraphWriteToAndImport(t *testing.T) {
+ // Construct a graph
+ g := NewGraph()
+ v, err := NewTensor(int64(1))
+ if err != nil {
+ t.Fatal(err)
+ }
+ input, err := Placeholder(g, "input", v.DataType())
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := Neg(g, "neg", input); err != nil {
+ t.Fatal(err)
+ }
+
+ // Serialize the graph
+ buf := new(bytes.Buffer)
+ if _, err := g.WriteTo(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ // Import it into the same graph, with a prefix
+ if err := g.Import(buf.Bytes(), "imported"); err != nil {
+ t.Error(err)
+ }
+ if err := hasOperations(g, "input", "neg", "imported/input", "imported/neg"); err != nil {
+ t.Error(err)
+ }
+}