diff options
author | Asim Shankar <ashankar@google.com> | 2016-10-07 09:56:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-07 11:04:39 -0700 |
commit | b5b1671163c90cf786fe0c65ca6754a85d017847 (patch) | |
tree | 313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/operation.go | |
parent | 318948b28076540841f0c079f12c8f4fdd3d15f1 (diff) |
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph.
I also intend to use this API in a code generator that will generate
Go sources files containing functions for each OpDef (and all
this generated code will be in a separate package).
While at it, also changed some tests to use the "sub-tests"
feature in Go 1.7 (https://blog.golang.org/subtests)
Another step in the journey of #10
Change: 135493412
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r-- | tensorflow/go/operation.go | 63 |
1 files changed, 18 insertions, 45 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 59c5c2a2b0..8f4fee1cbe 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -20,7 +20,6 @@ import "C" import ( "errors" - "unsafe" ) // Operation that has been added to the graph. @@ -47,8 +46,6 @@ func (op *Operation) NumOutputs() int { } // Output returns the i-th output of op. -// -// REQUIRES: 0 <= i < op.NumOutputs() func (op *Operation) Output(i int) Output { return Output{op, i} } @@ -97,52 +94,28 @@ func (p Output) Shape() (shape []int64, err error) { return ret, nil } -func (p *Output) c() C.TF_Port { +func (p Output) c() C.TF_Port { return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)} } -// opBuilder is for use by the generated op code to create new Operations. -// Build() must be called for any in-progress Operation, or else we leak. -type opBuilder struct { - c *C.TF_OperationDescription - // A reference to the Graph to prevent it from being GCed while - // the opBuilder is still alive. - g *Graph -} - -func newOpBuilder(g *Graph, typ string, name string) *opBuilder { - opType := C.CString(typ) - opName := C.CString(name) - b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName), g: g} - C.free(unsafe.Pointer(opType)) - C.free(unsafe.Pointer(opName)) - return b -} - -func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error { - status := newStatus() - attrName := C.CString(name) - C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c) - C.free(unsafe.Pointer(attrName)) - return status.Err() -} +func (p Output) canBeAnInput() {} -func (b *opBuilder) SetAttrType(name string, typ DataType) { - attrName := C.CString(name) - C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ)) - C.free(unsafe.Pointer(attrName)) +// Input is the interface for specifying inputs to an operation being added to +// a Graph. +// +// Operations can have multiple inputs, each of which could be either a tensor +// produced by another operation (an Output object), or a list of tensors +// produced by other operations (an OutputList). Thus, this interface is +// implemented by both Output and OutputList. +// +// See OpSpec.Input for more information. +type Input interface { + // Unexported to preclude implementations outside this package. + canBeAnInput() } -func (b *opBuilder) AddInput(port Output) { - C.TF_AddInput(b.c, port.c()) -} +// OutputList represents a list of Outputs that can be provided as input to +// another operation. +type OutputList []Output -func (b *opBuilder) Build() (*Operation, error) { - status := newStatus() - op := &Operation{c: C.TF_FinishOperation(b.c, status.c), g: b.g} - if err := status.Err(); err != nil { - return nil, err - } - b.c = nil - return op, nil -} +func (l OutputList) canBeAnInput() {} |