diff options
author | Asim Shankar <ashankar@google.com> | 2016-10-07 09:56:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-07 11:04:39 -0700 |
commit | b5b1671163c90cf786fe0c65ca6754a85d017847 (patch) | |
tree | 313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/util_test.go | |
parent | 318948b28076540841f0c079f12c8f4fdd3d15f1 (diff) |
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph.
I also intend to use this API in a code generator that will generate
Go sources files containing functions for each OpDef (and all
this generated code will be in a separate package).
While at it, also changed some tests to use the "sub-tests"
feature in Go 1.7 (https://blog.golang.org/subtests)
Another step in the journey of #10
Change: 135493412
Diffstat (limited to 'tensorflow/go/util_test.go')
-rw-r--r-- | tensorflow/go/util_test.go | 45 |
1 files changed, 31 insertions, 14 deletions
diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go index 06b4e61d0f..8ab365c656 100644 --- a/tensorflow/go/util_test.go +++ b/tensorflow/go/util_test.go @@ -15,23 +15,40 @@ package tensorflow func Placeholder(g *Graph, name string, dt DataType) (Output, error) { - b := newOpBuilder(g, "Placeholder", name) - b.SetAttrType("dtype", dt) - op, err := b.Build() - return Output{op, 0}, err + op, err := g.AddOperation(OpSpec{ + Type: "Placeholder", + Name: name, + Attrs: map[string]interface{}{ + "dtype": dt, + }, + }) + return op.Output(0), err } -func Const(g *Graph, name string, t *Tensor) (Output, error) { - b := newOpBuilder(g, "Const", name) - b.SetAttrType("dtype", t.DataType()) - b.SetAttrTensor("value", t) - op, err := b.Build() - return Output{op, 0}, err +func Const(g *Graph, name string, value interface{}) (Output, error) { + t, ok := value.(*Tensor) + if !ok { + var err error + if t, err = NewTensor(value); err != nil { + return Output{}, err + } + } + op, err := g.AddOperation(OpSpec{ + Type: "Const", + Name: name, + Attrs: map[string]interface{}{ + "dtype": t.DataType(), + "value": t, + }, + }) + return op.Output(0), err } func Neg(g *Graph, name string, port Output) (Output, error) { - b := newOpBuilder(g, "Neg", name) - b.AddInput(port) - op, err := b.Build() - return Output{op, 0}, err + op, err := g.AddOperation(OpSpec{ + Type: "Neg", + Name: name, + Input: []Input{port}, + }) + return op.Output(0), err } |