aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/util_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-07 09:56:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 11:04:39 -0700
commitb5b1671163c90cf786fe0c65ca6754a85d017847 (patch)
tree313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/util_test.go
parent318948b28076540841f0c079f12c8f4fdd3d15f1 (diff)
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph. I also intend to use this API in a code generator that will generate Go sources files containing functions for each OpDef (and all this generated code will be in a separate package). While at it, also changed some tests to use the "sub-tests" feature in Go 1.7 (https://blog.golang.org/subtests) Another step in the journey of #10 Change: 135493412
Diffstat (limited to 'tensorflow/go/util_test.go')
-rw-r--r--tensorflow/go/util_test.go45
1 files changed, 31 insertions, 14 deletions
diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go
index 06b4e61d0f..8ab365c656 100644
--- a/tensorflow/go/util_test.go
+++ b/tensorflow/go/util_test.go
@@ -15,23 +15,40 @@
package tensorflow
func Placeholder(g *Graph, name string, dt DataType) (Output, error) {
- b := newOpBuilder(g, "Placeholder", name)
- b.SetAttrType("dtype", dt)
- op, err := b.Build()
- return Output{op, 0}, err
+ op, err := g.AddOperation(OpSpec{
+ Type: "Placeholder",
+ Name: name,
+ Attrs: map[string]interface{}{
+ "dtype": dt,
+ },
+ })
+ return op.Output(0), err
}
-func Const(g *Graph, name string, t *Tensor) (Output, error) {
- b := newOpBuilder(g, "Const", name)
- b.SetAttrType("dtype", t.DataType())
- b.SetAttrTensor("value", t)
- op, err := b.Build()
- return Output{op, 0}, err
+func Const(g *Graph, name string, value interface{}) (Output, error) {
+ t, ok := value.(*Tensor)
+ if !ok {
+ var err error
+ if t, err = NewTensor(value); err != nil {
+ return Output{}, err
+ }
+ }
+ op, err := g.AddOperation(OpSpec{
+ Type: "Const",
+ Name: name,
+ Attrs: map[string]interface{}{
+ "dtype": t.DataType(),
+ "value": t,
+ },
+ })
+ return op.Output(0), err
}
func Neg(g *Graph, name string, port Output) (Output, error) {
- b := newOpBuilder(g, "Neg", name)
- b.AddInput(port)
- op, err := b.Build()
- return Output{op, 0}, err
+ op, err := g.AddOperation(OpSpec{
+ Type: "Neg",
+ Name: name,
+ Input: []Input{port},
+ })
+ return op.Output(0), err
}